Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/cast_to_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::Arc;
use arrow::array::{Array, ArrayRef, AsArray, StructArray};
use arrow_schema::{DataType, Field};
use datafusion::{
common::exec_err,
error::Result,
logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Expand All @@ -14,7 +13,9 @@ use datafusion::{
use parquet_variant::Variant;
use parquet_variant_compute::{VariantArray, VariantArrayBuilder, cast_to_variant};

use crate::shared::{try_parse_binary_columnar, try_parse_binary_scalar};
use crate::shared::{
arg_shape_err, args_count_err, try_parse_binary_columnar, try_parse_binary_scalar,
};

#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CastToVariantUdf {
Expand Down Expand Up @@ -57,15 +58,19 @@ impl CastToVariantUdf {
}

fn from_metadata_value(
udf_name: &str,
metadata_argument: &ColumnarValue,
variant_argument: &ColumnarValue,
) -> Result<ColumnarValue> {
let out = match (metadata_argument, variant_argument) {
(ColumnarValue::Array(metadata_array), ColumnarValue::Array(value_array)) => {
if metadata_array.len() != value_array.len() {
return exec_err!(
"expected metadata array to be of same length as variant array"
);
return Err(arg_shape_err(
udf_name,
2,
"array with same length as arg #1",
"array with different length",
));
}

let metadata_array = try_parse_binary_columnar(metadata_array)?;
Expand Down Expand Up @@ -180,11 +185,11 @@ impl ScalarUDFImpl for CastToVariantUdf {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
match args.args.as_slice() {
[metadata_value, variant_value] => {
Self::from_metadata_value(metadata_value, variant_value)
Self::from_metadata_value(self.name(), metadata_value, variant_value)
}
[ColumnarValue::Scalar(scalar_value)] => Self::from_scalar_value(scalar_value),
[ColumnarValue::Array(array)] => Self::from_array(array),
_ => exec_err!("unrecognized argument"),
_ => Err(args_count_err(self.name(), "1 or 2", args.args.len())),
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions src/is_variant_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray};
use arrow_schema::DataType;
use datafusion::common::{exec_datafusion_err, exec_err};
use datafusion::error::Result;
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
Expand All @@ -11,7 +10,10 @@ use datafusion::scalar::ScalarValue;
use parquet_variant::Variant;
use parquet_variant_compute::VariantArray;

use crate::shared::{try_field_as_variant_array, try_parse_variant_scalar};
use crate::shared::{
arg_field_meta_missing_err, args_count_err, try_field_as_variant_array,
try_parse_variant_scalar,
};

#[derive(Debug, Hash, PartialEq, Eq)]
pub struct IsVariantNullUdf {
Expand Down Expand Up @@ -47,12 +49,12 @@ impl ScalarUDFImpl for IsVariantNullUdf {
let variant_field = args
.arg_fields
.first()
.ok_or_else(|| exec_datafusion_err!("expected 1 argument field type"))?;
.ok_or_else(|| arg_field_meta_missing_err(self.name(), 1))?;

try_field_as_variant_array(variant_field.as_ref())?;

let [variant_arg] = args.args.as_slice() else {
return exec_err!("expected 1 argument");
return Err(args_count_err(self.name(), "1", args.args.len()));
};

let out = match variant_arg {
Expand Down
17 changes: 12 additions & 5 deletions src/json_to_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;
use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray, StructArray};
use arrow_schema::{DataType, Field, Fields};
use datafusion::{
common::{exec_datafusion_err, exec_err},
common::exec_datafusion_err,
error::Result,
logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
Expand All @@ -15,7 +15,7 @@ use datafusion::{
use parquet_variant_compute::{VariantArrayBuilder, VariantType};
use parquet_variant_json::JsonToVariant as JsonToVariantExt;

use crate::shared::{try_field_as_string, try_parse_string_scalar};
use crate::shared::{arg_type_err, args_count_err, try_field_as_string, try_parse_string_scalar};

/// Returns a Variant from a JSON string
#[derive(Debug, Hash, PartialEq, Eq)]
Expand Down Expand Up @@ -74,14 +74,14 @@ impl ScalarUDFImpl for JsonToVariantUdf {
let arg_field = args
.arg_fields
.first()
.ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
.ok_or_else(|| args_count_err(self.name(), "1", args.arg_fields.len()))?;

try_field_as_string(arg_field.as_ref())?;

let arg = args
.args
.first()
.ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
.ok_or_else(|| args_count_err(self.name(), "1", args.args.len()))?;

let out = match arg {
ColumnarValue::Scalar(scalar_value) => {
Expand All @@ -101,7 +101,14 @@ impl ScalarUDFImpl for JsonToVariantUdf {
DataType::Utf8 => ColumnarValue::Array(from_utf8_arr(arr)?),
DataType::LargeUtf8 => ColumnarValue::Array(from_large_utf8_arr(arr)?),
DataType::Utf8View => ColumnarValue::Array(from_utf8view_arr(arr)?),
_ => return exec_err!("Invalid data type {}", arr.data_type()),
_ => {
return arg_type_err(
self.name(),
1,
"Utf8, LargeUtf8, or Utf8View",
arr.data_type(),
);
}
},
};

Expand Down
56 changes: 55 additions & 1 deletion src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use arrow_schema::Fields;
use arrow_schema::extension::ExtensionType;
use arrow_schema::{DataType, Field};
use datafusion::common::exec_datafusion_err;
use datafusion::error::Result;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion::{common::exec_err, scalar::ScalarValue};
use parquet_variant::Variant;
Expand Down Expand Up @@ -262,6 +262,60 @@ pub fn ensure(pred: bool, err_msg: &str) -> Result<()> {
Ok(())
}

/// Helper for argument count errors.
///
/// Returns an execution error of the form
/// `"{udf}: expected {expected} argument(s), got {actual}"`.
///
/// `expected` is a human-readable arity description (e.g. `"1 or 2"`),
/// so one helper covers both exact and ranged argument counts.
pub fn args_count_err(udf: &str, expected: &str, actual: usize) -> DataFusionError {
    // `&str` (rather than `&'static str`) keeps the signature consistent
    // with the other arg_* helpers in this module and also accepts
    // dynamically built expectation strings; existing call sites that pass
    // string literals are unaffected.
    DataFusionError::Execution(format!(
        "{udf}: expected {expected} argument(s), got {actual}"
    ))
}

/// Helper for argument type errors.
///
/// Builds an execution error stating that argument `arg_index` of `udf`
/// had data type `actual` where `expected` (a description of the accepted
/// types) was required. Generic over `T` and pre-wrapped in `Err` so call
/// sites can `return arg_type_err(...)` from any `Result`-returning
/// context.
pub fn arg_type_err<T>(
    udf: &str,
    arg_index: u8,
    expected: &str,
    actual: &DataType,
) -> Result<T, DataFusionError> {
    let message = format!("{udf} arg #{arg_index}: expected {expected}, got {actual}");
    Err(DataFusionError::Execution(message))
}

/// Helper for unexpected NULL argument values.
pub fn arg_null_err<T>(udf: &str, arg_index: u8, expected: &str) -> Result<T, DataFusionError> {
Err(arg_null_error(udf, arg_index, expected))
}

/// Helper for unexpected NULL argument values as a plain DataFusionError.
///
/// Reports that argument `arg_index` of `udf` was NULL where `expected`
/// (a description of the required value) was needed. Returns the bare
/// error for contexts such as `ok_or_else` closures.
pub fn arg_null_error(udf: &str, arg_index: u8, expected: &str) -> DataFusionError {
    let message = format!("{udf} arg #{arg_index}: expected {expected}, got NULL");
    DataFusionError::Execution(message)
}

/// Helper for scalar/array shape mismatches.
///
/// Reports that argument `arg_index` of `udf` had the wrong "shape"
/// (scalar vs. array, or an array of the wrong length): `expected` and
/// `actual` are free-form descriptions of the required and observed
/// shapes.
pub fn arg_shape_err(udf: &str, arg_index: u8, expected: &str, actual: &str) -> DataFusionError {
    let message = format!("{udf} arg #{arg_index}: expected {expected}, got {actual}");
    DataFusionError::Execution(message)
}

/// Helper for invalid Variant kind in an argument.
///
/// Reports that argument `arg_index` of `udf` held a variant whose kind
/// was not `expected_variant_kind` (e.g. `"list"`). The observed kind is
/// not included because call sites only know the expectation that failed.
pub fn arg_variant_kind_err(
    udf: &str,
    arg_index: u8,
    expected_variant_kind: &str,
) -> DataFusionError {
    let message = format!("{udf} arg #{arg_index}: expected variant {expected_variant_kind}");
    DataFusionError::Execution(message)
}

/// Helper for missing argument field metadata.
///
/// Reports that the `Field` describing argument `arg_index` of `udf` was
/// absent from the invocation's argument fields.
pub fn arg_field_meta_missing_err(udf: &str, arg_index: u8) -> DataFusionError {
    let message = format!("{udf} arg #{arg_index} field metadata is missing");
    DataFusionError::Execution(message)
}

// test related methods

#[cfg(test)]
Expand Down
48 changes: 31 additions & 17 deletions src/variant_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use arrow::{
};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields};
use datafusion::{
common::{arrow_datafusion_err, exec_datafusion_err, exec_err},
common::{arrow_datafusion_err, exec_datafusion_err},
error::{DataFusionError, Result},
logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Expand All @@ -22,19 +22,29 @@ use parquet_variant_json::VariantToJson;

use crate::impl_variant_get::impl_variant_get_typed;
use crate::shared::{
arg_field_meta_missing_err, arg_null_err, arg_shape_err, arg_type_err, args_count_err,
invoke_variant_get_typed, try_field_as_variant_array, try_parse_string_columnar,
try_parse_string_scalar,
};

fn type_hint_from_scalar(field_name: &str, scalar: &ScalarValue) -> Result<FieldRef> {
fn type_hint_from_scalar(
udf_name: &str,
field_name: &str,
scalar: &ScalarValue,
) -> Result<FieldRef> {
let type_name = match scalar {
ScalarValue::Utf8(Some(value))
| ScalarValue::Utf8View(Some(value))
| ScalarValue::LargeUtf8(Some(value)) => value.as_str(),
ScalarValue::Utf8(None) | ScalarValue::Utf8View(None) | ScalarValue::LargeUtf8(None) => {
return arg_null_err(udf_name, 3, "a non-null UTF8 literal");
}
other => {
return exec_err!(
"type hint must be a non-null UTF8 literal, got {}",
other.data_type()
return arg_type_err(
udf_name,
3,
"Utf8, LargeUtf8, or Utf8View",
&other.data_type(),
);
}
};
Expand All @@ -48,12 +58,10 @@ fn type_hint_from_scalar(field_name: &str, scalar: &ScalarValue) -> Result<Field
Ok(Arc::new(Field::new(field_name, casted_type, true)))
}

fn type_hint_from_value(field_name: &str, arg: &ColumnarValue) -> Result<FieldRef> {
fn type_hint_from_value(udf_name: &str, field_name: &str, arg: &ColumnarValue) -> Result<FieldRef> {
match arg {
ColumnarValue::Scalar(value) => type_hint_from_scalar(field_name, value),
ColumnarValue::Array(_) => {
exec_err!("type hint argument must be a scalar UTF8 literal")
}
ColumnarValue::Scalar(value) => type_hint_from_scalar(udf_name, field_name, value),
ColumnarValue::Array(_) => Err(arg_shape_err(udf_name, 3, "scalar value", "array value")),
}
}

Expand Down Expand Up @@ -96,18 +104,19 @@ fn invoke_variant_get(
let (variant_arg, variant_path, type_arg) = match args.args.as_slice() {
[variant_arg, variant_path] => (variant_arg, variant_path, None),
[variant_arg, variant_path, type_arg] => (variant_arg, variant_path, Some(type_arg)),
_ => return exec_err!("expected 2 or 3 arguments"),
_ => return Err(args_count_err(udf_name, "2 or 3", args.args.len())),
};

let variant_field = args
.arg_fields
.first()
.ok_or_else(|| exec_datafusion_err!("expected argument field"))?;
.ok_or_else(|| arg_field_meta_missing_err(udf_name, 1))?;

try_field_as_variant_array(variant_field.as_ref())?;

let type_field_name = args.return_field.name();
let type_field = type_arg
.map(|arg| type_hint_from_value(udf_name, arg))
.map(|arg| type_hint_from_value(udf_name, type_field_name, arg))
.transpose()?;

let out = match (variant_arg, variant_path) {
Expand All @@ -125,7 +134,7 @@ fn invoke_variant_get(
}
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(variant_path)) => {
let ScalarValue::Struct(variant_array) = scalar_variant else {
return exec_err!("expected struct array");
return arg_type_err(udf_name, 1, "Struct", &scalar_variant.data_type());
};

let variant_array = Arc::clone(variant_array) as ArrayRef;
Expand All @@ -144,7 +153,12 @@ fn invoke_variant_get(
}
(ColumnarValue::Array(variant_array), ColumnarValue::Array(variant_paths)) => {
if variant_array.len() != variant_paths.len() {
return exec_err!("expected variant_array and variant paths to be of same length");
return Err(arg_shape_err(
udf_name,
2,
"array with same length as arg #1",
"array with different length",
));
}

let variant_paths = try_parse_string_columnar(variant_paths)?;
Expand Down Expand Up @@ -175,7 +189,7 @@ fn invoke_variant_get(
}
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Array(variant_paths)) => {
let ScalarValue::Struct(variant_array) = scalar_variant else {
return exec_err!("expected struct array");
return arg_type_err(udf_name, 1, "Struct", &scalar_variant.data_type());
};

let variant_array = Arc::clone(variant_array) as ArrayRef;
Expand Down Expand Up @@ -206,7 +220,7 @@ fn return_field_for_variant_get(name: &str, args: ReturnFieldArgs) -> Result<Arc
let scalar = maybe_scalar.ok_or_else(|| {
exec_datafusion_err!("type hint argument to {name} must be a literal")
})?;
return type_hint_from_scalar(name, scalar);
return type_hint_from_scalar(name, name, scalar);
}

let data_type = DataType::Struct(Fields::from(vec![
Expand Down
15 changes: 11 additions & 4 deletions src/variant_list_delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use datafusion::{
use parquet_variant::{Variant, VariantBuilder};
use parquet_variant_compute::{VariantArray, VariantType};

use crate::shared::{ensure, try_parse_variant_scalar};
use crate::shared::{
arg_shape_err, arg_variant_kind_err, args_count_err, ensure, try_parse_variant_scalar,
};

#[derive(Debug, Hash, PartialEq, Eq)]
pub struct VariantListDelete {
Expand Down Expand Up @@ -41,7 +43,7 @@ fn try_parse_index_scalar(scalar: &ScalarValue) -> Result<usize> {

fn delete_list_element(variant_list: Variant, index: usize) -> Result<(Vec<u8>, Vec<u8>)> {
let Variant::List(variant_list) = variant_list else {
return exec_err!("expected variant list");
return Err(arg_variant_kind_err("variant_list_delete", 1, "list"));
};

if index >= variant_list.len() {
Expand Down Expand Up @@ -109,7 +111,7 @@ impl ScalarUDFImpl for VariantListDelete {
)?;

let [variant_list_to_update, index_to_delete] = argument_values.as_slice() else {
return exec_err!("expected 2 arguments");
return Err(args_count_err(self.name(), "2", argument_values.len()));
};

ensure(
Expand All @@ -119,7 +121,12 @@ impl ScalarUDFImpl for VariantListDelete {

let index = {
let ColumnarValue::Scalar(index) = index_to_delete else {
return exec_err!("expected scalar value for index");
return Err(arg_shape_err(
self.name(),
2,
"scalar integer value",
"array value",
));
};

try_parse_index_scalar(index)?
Expand Down
Loading