Skip to content

Commit 620c577

Browse files
committed
feat: support url encode expresson
1 parent c62bbb3 commit 620c577

File tree

4 files changed

+336
-3
lines changed

4 files changed

+336
-3
lines changed

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use crate::{
2323
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
2424
spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
2525
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff,
26-
SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
26+
SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, SparkUrlEncode,
2727
};
2828
use arrow::datatypes::DataType;
2929
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -198,6 +198,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
198198
Arc::new(ScalarUDF::new_from_impl(SparkMakeDate::default())),
199199
Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
200200
Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
201+
Arc::new(ScalarUDF::new_from_impl(SparkUrlEncode::default())),
201202
]
202203
}
203204

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
mod contains;
1919
mod string_space;
2020
mod substring;
21+
mod url_encode;
2122

2223
pub use contains::SparkContains;
2324
pub use string_space::SparkStringSpace;
2425
pub use substring::SubstringExpr;
26+
pub use url_encode::SparkUrlEncode;
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
as_dictionary_array, make_array, Array, ArrayData, ArrayRef, DictionaryArray,
20+
GenericStringArray, OffsetSizeTrait, StringArray,
21+
};
22+
use arrow::buffer::MutableBuffer;
23+
use arrow::datatypes::{DataType, Int32Type};
24+
use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue};
25+
use datafusion::logical_expr::{
26+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
27+
};
28+
use std::{any::Any, sync::Arc};
29+
30+
/// Spark-compatible URL encoding following application/x-www-form-urlencoded format.
31+
/// This matches Java's URLEncoder.encode behavior used by Spark's UrlCodec.encode.
32+
///
33+
/// Key behaviors:
34+
/// - Spaces are encoded as '+' (not '%20')
35+
/// - Alphanumeric characters (a-z, A-Z, 0-9) are not encoded
36+
/// - Special characters '.', '-', '*', '_' are not encoded
37+
/// - All other characters are percent-encoded using UTF-8 bytes
38+
#[derive(Debug, PartialEq, Eq, Hash)]
39+
pub struct SparkUrlEncode {
40+
signature: Signature,
41+
aliases: Vec<String>,
42+
}
43+
44+
impl Default for SparkUrlEncode {
45+
fn default() -> Self {
46+
Self::new()
47+
}
48+
}
49+
50+
impl SparkUrlEncode {
51+
pub fn new() -> Self {
52+
Self {
53+
signature: Signature::user_defined(Volatility::Immutable),
54+
aliases: vec![],
55+
}
56+
}
57+
}
58+
59+
impl ScalarUDFImpl for SparkUrlEncode {
60+
fn as_any(&self) -> &dyn Any {
61+
self
62+
}
63+
64+
fn name(&self) -> &str {
65+
"url_encode"
66+
}
67+
68+
fn signature(&self) -> &Signature {
69+
&self.signature
70+
}
71+
72+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
73+
Ok(match &arg_types[0] {
74+
DataType::Dictionary(key_type, _) => {
75+
DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
76+
}
77+
_ => DataType::Utf8,
78+
})
79+
}
80+
81+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
82+
if args.args.len() != 1 {
83+
return Err(internal_datafusion_err!(
84+
"url_encode expects exactly one argument, got {}",
85+
args.args.len()
86+
));
87+
}
88+
let args: [ColumnarValue; 1] = args
89+
.args
90+
.try_into()
91+
.map_err(|_| internal_datafusion_err!("url_encode expects exactly one argument"))?;
92+
spark_url_encode(&args)
93+
}
94+
95+
fn aliases(&self) -> &[String] {
96+
&self.aliases
97+
}
98+
}
99+
100+
pub fn spark_url_encode(args: &[ColumnarValue; 1]) -> Result<ColumnarValue> {
101+
match args {
102+
[ColumnarValue::Array(array)] => {
103+
let result = url_encode_array(array.as_ref())?;
104+
Ok(ColumnarValue::Array(result))
105+
}
106+
[ColumnarValue::Scalar(scalar)] => {
107+
let result = url_encode_scalar(scalar)?;
108+
Ok(ColumnarValue::Scalar(result))
109+
}
110+
}
111+
}
112+
113+
fn url_encode_array(input: &dyn Array) -> Result<ArrayRef> {
114+
match input.data_type() {
115+
DataType::Utf8 => {
116+
let array = input.as_any().downcast_ref::<StringArray>().unwrap();
117+
Ok(url_encode_string_array::<i32>(array))
118+
}
119+
DataType::LargeUtf8 => {
120+
let array = input
121+
.as_any()
122+
.downcast_ref::<GenericStringArray<i64>>()
123+
.unwrap();
124+
Ok(url_encode_string_array::<i64>(array))
125+
}
126+
DataType::Dictionary(_, _) => {
127+
let dict = as_dictionary_array::<Int32Type>(input);
128+
let values = url_encode_array(dict.values())?;
129+
let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
130+
Ok(Arc::new(result))
131+
}
132+
other => exec_err!("Unsupported input type for function 'url_encode': {other:?}"),
133+
}
134+
}
135+
136+
fn url_encode_scalar(scalar: &ScalarValue) -> Result<ScalarValue> {
137+
match scalar {
138+
ScalarValue::Utf8(value) | ScalarValue::LargeUtf8(value) => {
139+
let result = value.as_ref().map(|s| url_encode_string(s));
140+
Ok(ScalarValue::Utf8(result))
141+
}
142+
ScalarValue::Null => Ok(ScalarValue::Utf8(None)),
143+
other => exec_err!("Unsupported data type {other:?} for function `url_encode`"),
144+
}
145+
}
146+
147+
fn url_encode_string_array<OffsetSize: OffsetSizeTrait>(
148+
input: &GenericStringArray<OffsetSize>,
149+
) -> ArrayRef {
150+
let array_len = input.len();
151+
let mut offsets = MutableBuffer::new((array_len + 1) * std::mem::size_of::<OffsetSize>());
152+
let mut values = MutableBuffer::new(input.values().len()); // reasonable initial capacity
153+
let mut offset_so_far = OffsetSize::zero();
154+
let null_bit_buffer = input.to_data().nulls().map(|b| b.buffer().clone());
155+
156+
offsets.push(offset_so_far);
157+
158+
for i in 0..array_len {
159+
if !input.is_null(i) {
160+
let encoded = url_encode_string(input.value(i));
161+
offset_so_far += OffsetSize::from_usize(encoded.len()).unwrap();
162+
values.extend_from_slice(encoded.as_bytes());
163+
}
164+
offsets.push(offset_so_far);
165+
}
166+
167+
let data = unsafe {
168+
ArrayData::new_unchecked(
169+
GenericStringArray::<OffsetSize>::DATA_TYPE,
170+
array_len,
171+
None,
172+
null_bit_buffer,
173+
0,
174+
vec![offsets.into(), values.into()],
175+
vec![],
176+
)
177+
};
178+
make_array(data)
179+
}
180+
181+
fn url_encode_length(s: &str) -> usize {
182+
let mut len = 0;
183+
for byte in s.bytes() {
184+
if should_encode(byte) {
185+
if byte == b' ' {
186+
len += 1; // space -> '+'
187+
} else {
188+
len += 3; // other -> %XX
189+
}
190+
} else {
191+
len += 1;
192+
}
193+
}
194+
len
195+
}
196+
197+
fn url_encode_string(s: &str) -> String {
198+
let mut buf = Vec::with_capacity(url_encode_length(s));
199+
for byte in s.bytes() {
200+
if !should_encode(byte) {
201+
buf.push(byte);
202+
} else if byte == b' ' {
203+
buf.push(b'+');
204+
} else {
205+
buf.push(b'%');
206+
buf.push(HEX_BYTES[(byte >> 4) as usize]);
207+
buf.push(HEX_BYTES[(byte & 0x0F) as usize]);
208+
}
209+
}
210+
211+
unsafe { String::from_utf8_unchecked(buf) }
212+
}
213+
214+
const HEX_BYTES: [u8; 16] = *b"0123456789ABCDEF";
215+
216+
/// Check if a byte should be encoded
217+
/// Returns true for characters that need to be percent-encoded
218+
fn should_encode(byte: u8) -> bool {
219+
// Unreserved characters per RFC 3986 that are NOT encoded by URLEncoder:
220+
// - Alphanumeric: A-Z, a-z, 0-9
221+
// - Special: '.', '-', '*', '_'
222+
// Note: '~' is unreserved in RFC 3986 but IS encoded by Java URLEncoder
223+
!matches!(byte,
224+
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' |
225+
b'.' | b'-' | b'*' | b'_'
226+
)
227+
}
228+
229+
230+
#[cfg(test)]
231+
mod tests {
232+
use super::*;
233+
use datafusion::common::cast::as_string_array;
234+
235+
#[test]
236+
fn test_url_encode_basic() {
237+
assert_eq!(url_encode_string("Hello World"), "Hello+World");
238+
assert_eq!(url_encode_string("foo=bar"), "foo%3Dbar");
239+
assert_eq!(url_encode_string("a+b"), "a%2Bb");
240+
assert_eq!(url_encode_string(""), "");
241+
}
242+
243+
#[test]
244+
fn test_url_encode_special_chars() {
245+
assert_eq!(url_encode_string("?"), "%3F");
246+
assert_eq!(url_encode_string("&"), "%26");
247+
assert_eq!(url_encode_string("="), "%3D");
248+
assert_eq!(url_encode_string("#"), "%23");
249+
assert_eq!(url_encode_string("/"), "%2F");
250+
assert_eq!(url_encode_string("%"), "%25");
251+
}
252+
253+
#[test]
254+
fn test_url_encode_unreserved_chars() {
255+
// These should NOT be encoded
256+
assert_eq!(url_encode_string("abc123"), "abc123");
257+
assert_eq!(url_encode_string("ABC"), "ABC");
258+
assert_eq!(url_encode_string("."), ".");
259+
assert_eq!(url_encode_string("-"), "-");
260+
assert_eq!(url_encode_string("*"), "*");
261+
assert_eq!(url_encode_string("_"), "_");
262+
}
263+
264+
#[test]
265+
fn test_url_encode_unicode() {
266+
// UTF-8 multi-byte characters should be percent-encoded
267+
assert_eq!(url_encode_string("cafe\u{0301}"), "cafe%CC%81");
268+
assert_eq!(url_encode_string("\u{00e9}"), "%C3%A9"); // é as single char
269+
}
270+
271+
#[test]
272+
fn test_url_encode_array() {
273+
let input = StringArray::from(vec![
274+
Some("Hello World"),
275+
Some("foo=bar"),
276+
None,
277+
Some(""),
278+
]);
279+
let args = ColumnarValue::Array(Arc::new(input));
280+
match spark_url_encode(&[args]) {
281+
Ok(ColumnarValue::Array(result)) => {
282+
let actual = as_string_array(&result).unwrap();
283+
assert_eq!(actual.value(0), "Hello+World");
284+
assert_eq!(actual.value(1), "foo%3Dbar");
285+
assert!(actual.is_null(2));
286+
assert_eq!(actual.value(3), "");
287+
}
288+
_ => unreachable!(),
289+
}
290+
}
291+
292+
#[test]
293+
fn test_url_encode_scalar() {
294+
let scalar = ScalarValue::Utf8(Some("Hello World".to_string()));
295+
let result = url_encode_scalar(&scalar).unwrap();
296+
assert_eq!(result, ScalarValue::Utf8(Some("Hello+World".to_string())));
297+
298+
let null_scalar = ScalarValue::Utf8(None);
299+
let null_result = url_encode_scalar(&null_scalar).unwrap();
300+
assert_eq!(null_result, ScalarValue::Utf8(None));
301+
}
302+
303+
#[test]
304+
fn test_url_encode_tilde() {
305+
// ~ is unreserved in RFC 3986 but Java URLEncoder encodes it
306+
assert_eq!(url_encode_string("~"), "%7E");
307+
}
308+
}

spark/src/main/scala/org/apache/comet/serde/statics.scala

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919

2020
package org.apache.comet.serde
2121

22-
import org.apache.spark.sql.catalyst.expressions.Attribute
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, UrlCodec}
2323
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
2424
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
2525

2626
import org.apache.comet.CometSparkSessionExtensions.withInfo
27+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
2728

2829
object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
2930

@@ -34,7 +35,28 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
3435
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
3536
Map(
3637
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
37-
"read_side_padding"))
38+
"read_side_padding"),
39+
("encode", UrlCodec.getClass) -> CometUrlEncode)
40+
41+
object CometUrlEncode extends CometExpressionSerde[StaticInvoke] {
42+
override def convert(
43+
expr: StaticInvoke,
44+
inputs: Seq[Attribute],
45+
binding: Boolean): Option[ExprOuterClass.Expr] = {
46+
// StaticInvoke for url_encode may include a second child (the UTF-8 Charset object),
47+
// which is not needed by the Rust backend — it always assumes UTF-8.
48+
// We only convert the first child (the string data).
49+
expr.children match {
50+
case Seq(dataToEncode, _*) =>
51+
val childExpr = exprToProtoInternal(dataToEncode, inputs, binding)
52+
val optExpr = scalarFunctionExprToProto("url_encode", childExpr)
53+
optExprWithInfo(optExpr, expr, dataToEncode)
54+
case _ =>
55+
withInfo(expr, "url_encode expected at least 1 argument but found none")
56+
None
57+
}
58+
}
59+
}
3860

3961
override def convert(
4062
expr: StaticInvoke,

0 commit comments

Comments
 (0)