Skip to content

Commit 34bbe78

Browse files
authored
feat: Cast numeric (non int) to timestamp (#3559)
* `float_to_timestamp`
* `non_numeric_to_timestamp`
1 parent bd42649 commit 34bbe78

File tree

8 files changed

+531
-48
lines changed

8 files changed

+531
-48
lines changed

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
103103
[[bench]]
104104
name = "cast_from_boolean"
105105
harness = false
106+
107+
[[bench]]
108+
name = "cast_non_int_numeric_timestamp"
109+
harness = false
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::builder::{BooleanBuilder, Decimal128Builder, Float32Builder, Float64Builder};
19+
use arrow::array::RecordBatch;
20+
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
21+
use criterion::{criterion_group, criterion_main, Criterion};
22+
use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
23+
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
24+
use std::sync::Arc;
25+
26+
const BATCH_SIZE: usize = 8192;
27+
28+
fn criterion_benchmark(c: &mut Criterion) {
29+
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
30+
let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
31+
32+
let mut group = c.benchmark_group("cast_non_int_numeric_to_timestamp");
33+
34+
// Float32 -> Timestamp
35+
let batch_f32 = create_float32_batch();
36+
let expr_f32 = Arc::new(Column::new("a", 0));
37+
let cast_f32_to_ts = Cast::new(expr_f32, timestamp_type.clone(), spark_cast_options.clone());
38+
group.bench_function("cast_f32_to_timestamp", |b| {
39+
b.iter(|| cast_f32_to_ts.evaluate(&batch_f32).unwrap());
40+
});
41+
42+
// Float64 -> Timestamp
43+
let batch_f64 = create_float64_batch();
44+
let expr_f64 = Arc::new(Column::new("a", 0));
45+
let cast_f64_to_ts = Cast::new(expr_f64, timestamp_type.clone(), spark_cast_options.clone());
46+
group.bench_function("cast_f64_to_timestamp", |b| {
47+
b.iter(|| cast_f64_to_ts.evaluate(&batch_f64).unwrap());
48+
});
49+
50+
// Boolean -> Timestamp
51+
let batch_bool = create_boolean_batch();
52+
let expr_bool = Arc::new(Column::new("a", 0));
53+
let cast_bool_to_ts = Cast::new(
54+
expr_bool,
55+
timestamp_type.clone(),
56+
spark_cast_options.clone(),
57+
);
58+
group.bench_function("cast_bool_to_timestamp", |b| {
59+
b.iter(|| cast_bool_to_ts.evaluate(&batch_bool).unwrap());
60+
});
61+
62+
// Decimal128 -> Timestamp
63+
let batch_decimal = create_decimal128_batch();
64+
let expr_decimal = Arc::new(Column::new("a", 0));
65+
let cast_decimal_to_ts = Cast::new(
66+
expr_decimal,
67+
timestamp_type.clone(),
68+
spark_cast_options.clone(),
69+
);
70+
group.bench_function("cast_decimal_to_timestamp", |b| {
71+
b.iter(|| cast_decimal_to_ts.evaluate(&batch_decimal).unwrap());
72+
});
73+
74+
group.finish();
75+
}
76+
77+
fn create_float32_batch() -> RecordBatch {
78+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
79+
let mut b = Float32Builder::with_capacity(BATCH_SIZE);
80+
for i in 0..BATCH_SIZE {
81+
if i % 10 == 0 {
82+
b.append_null();
83+
} else {
84+
b.append_value(rand::random::<f32>());
85+
}
86+
}
87+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
88+
}
89+
90+
fn create_float64_batch() -> RecordBatch {
91+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
92+
let mut b = Float64Builder::with_capacity(BATCH_SIZE);
93+
for i in 0..BATCH_SIZE {
94+
if i % 10 == 0 {
95+
b.append_null();
96+
} else {
97+
b.append_value(rand::random::<f64>());
98+
}
99+
}
100+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
101+
}
102+
103+
fn create_boolean_batch() -> RecordBatch {
104+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)]));
105+
let mut b = BooleanBuilder::with_capacity(BATCH_SIZE);
106+
for i in 0..BATCH_SIZE {
107+
if i % 10 == 0 {
108+
b.append_null();
109+
} else {
110+
b.append_value(rand::random::<bool>());
111+
}
112+
}
113+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
114+
}
115+
116+
fn create_decimal128_batch() -> RecordBatch {
117+
let schema = Arc::new(Schema::new(vec![Field::new(
118+
"a",
119+
DataType::Decimal128(18, 6),
120+
true,
121+
)]));
122+
let mut b = Decimal128Builder::with_capacity(BATCH_SIZE);
123+
for i in 0..BATCH_SIZE {
124+
if i % 10 == 0 {
125+
b.append_null();
126+
} else {
127+
b.append_value(i as i128 * 1_000_000);
128+
}
129+
}
130+
let array = b.finish().with_precision_and_scale(18, 6).unwrap();
131+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
132+
}
133+
134+
fn config() -> Criterion {
135+
Criterion::default()
136+
}
137+
138+
criterion_group! {
139+
name = benches;
140+
config = config();
141+
targets = criterion_benchmark
142+
}
143+
criterion_main!(benches);

native/spark-expr/src/conversion_funcs/boolean.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::SparkResult;
19-
use arrow::array::{ArrayRef, AsArray, Decimal128Array};
19+
use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array, TimestampMicrosecondBuilder};
2020
use arrow::datatypes::DataType;
2121
use std::sync::Arc;
2222

@@ -28,7 +28,6 @@ pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool {
2828
)
2929
}
3030

31-
// only DF incompatible boolean cast
3231
pub fn cast_boolean_to_decimal(
3332
array: &ArrayRef,
3433
precision: u8,
@@ -43,6 +42,25 @@ pub fn cast_boolean_to_decimal(
4342
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
4443
}
4544

45+
pub(crate) fn cast_boolean_to_timestamp(
46+
array_ref: &ArrayRef,
47+
target_tz: &Option<Arc<str>>,
48+
) -> SparkResult<ArrayRef> {
49+
let bool_array = array_ref.as_boolean();
50+
let mut builder = TimestampMicrosecondBuilder::with_capacity(bool_array.len());
51+
52+
for i in 0..bool_array.len() {
53+
if bool_array.is_null(i) {
54+
builder.append_null();
55+
} else {
56+
let micros = if bool_array.value(i) { 1 } else { 0 };
57+
builder.append_value(micros);
58+
}
59+
}
60+
61+
Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
62+
}
63+
4664
#[cfg(test)]
4765
mod tests {
4866
use super::*;
@@ -53,6 +71,7 @@ mod tests {
5371
Int64Array, Int8Array, StringArray,
5472
};
5573
use arrow::datatypes::DataType::Decimal128;
74+
use arrow::datatypes::TimestampMicrosecondType;
5675
use std::sync::Arc;
5776

5877
fn test_input_bool_array() -> ArrayRef {
@@ -193,4 +212,26 @@ mod tests {
193212
assert_eq!(arr.value(1), expected_arr.value(1));
194213
assert!(arr.is_null(2));
195214
}
215+
216+
#[test]
217+
fn test_cast_boolean_to_timestamp() {
218+
let timezones: [Option<Arc<str>>; 3] = [
219+
Some(Arc::from("UTC")),
220+
Some(Arc::from("America/Los_Angeles")),
221+
None,
222+
];
223+
224+
for tz in &timezones {
225+
let bool_array: ArrayRef =
226+
Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]));
227+
228+
let result = cast_boolean_to_timestamp(&bool_array, tz).unwrap();
229+
let ts_array = result.as_primitive::<TimestampMicrosecondType>();
230+
231+
assert_eq!(ts_array.value(0), 1); // true -> 1 microsecond
232+
assert_eq!(ts_array.value(1), 0); // false -> 0 (epoch)
233+
assert!(ts_array.is_null(2));
234+
assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
235+
}
236+
}
196237
}

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
// under the License.
1717

1818
use crate::conversion_funcs::boolean::{
19-
cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
19+
cast_boolean_to_decimal, cast_boolean_to_timestamp, is_df_cast_from_bool_spark_compatible,
2020
};
2121
use crate::conversion_funcs::numeric::{
22-
cast_float32_to_decimal128, cast_float64_to_decimal128, cast_int_to_decimal128,
23-
cast_int_to_timestamp, is_df_cast_from_decimal_spark_compatible,
24-
is_df_cast_from_float_spark_compatible, is_df_cast_from_int_spark_compatible,
25-
spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, spark_cast_float64_to_utf8,
26-
spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral,
22+
cast_decimal_to_timestamp, cast_float32_to_decimal128, cast_float64_to_decimal128,
23+
cast_float_to_timestamp, cast_int_to_decimal128, cast_int_to_timestamp,
24+
is_df_cast_from_decimal_spark_compatible, is_df_cast_from_float_spark_compatible,
25+
is_df_cast_from_int_spark_compatible, spark_cast_decimal_to_boolean,
26+
spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, spark_cast_int_to_int,
27+
spark_cast_nonintegral_numeric_to_integral,
2728
};
2829
use crate::conversion_funcs::string::{
2930
cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
@@ -384,6 +385,9 @@ pub(crate) fn cast_array(
384385
cast_boolean_to_decimal(&array, *precision, *scale)
385386
}
386387
(Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz),
388+
(Float32 | Float64, Timestamp(_, tz)) => cast_float_to_timestamp(&array, tz, eval_mode),
389+
(Boolean, Timestamp(_, tz)) => cast_boolean_to_timestamp(&array, tz),
390+
(Decimal128(_, scale), Timestamp(_, tz)) => cast_decimal_to_timestamp(&array, tz, *scale),
387391
_ if cast_options.is_adapting_schema
388392
|| is_datafusion_spark_compatible(&from_type, to_type) =>
389393
{

0 commit comments

Comments
 (0)