1+ // Licensed to the Apache Software Foundation (ASF) under one
2+ // or more contributor license agreements. See the NOTICE file
3+ // distributed with this work for additional information
4+ // regarding copyright ownership. The ASF licenses this file
5+ // to you under the Apache License, Version 2.0 (the
6+ // "License"); you may not use this file except in compliance
7+ // with the License. You may obtain a copy of the License at
8+ //
9+ // http://www.apache.org/licenses/LICENSE-2.0
10+ //
11+ // Unless required by applicable law or agreed to in writing,
12+ // software distributed under the License is distributed on an
13+ // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+ // KIND, either express or implied. See the License for the
15+ // specific language governing permissions and limitations
16+ // under the License.
17+
18+ use arrow:: array:: {
19+ as_dictionary_array, make_array, Array , ArrayData , ArrayRef , DictionaryArray ,
20+ GenericStringArray , OffsetSizeTrait , StringArray ,
21+ } ;
22+ use arrow:: buffer:: MutableBuffer ;
23+ use arrow:: datatypes:: { DataType , Int32Type } ;
24+ use datafusion:: common:: { exec_err, internal_datafusion_err, Result , ScalarValue } ;
25+ use datafusion:: logical_expr:: {
26+ ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
27+ } ;
28+ use std:: { any:: Any , sync:: Arc } ;
29+
30+ /// Spark-compatible URL encoding following application/x-www-form-urlencoded format.
31+ /// This matches Java's URLEncoder.encode behavior used by Spark's UrlCodec.encode.
32+ ///
33+ /// Key behaviors:
34+ /// - Spaces are encoded as '+' (not '%20')
35+ /// - Alphanumeric characters (a-z, A-Z, 0-9) are not encoded
36+ /// - Special characters '.', '-', '*', '_' are not encoded
37+ /// - All other characters are percent-encoded using UTF-8 bytes
38+ #[ derive( Debug , PartialEq , Eq , Hash ) ]
39+ pub struct SparkUrlEncode {
40+ signature : Signature ,
41+ aliases : Vec < String > ,
42+ }
43+
44+ impl Default for SparkUrlEncode {
45+ fn default ( ) -> Self {
46+ Self :: new ( )
47+ }
48+ }
49+
50+ impl SparkUrlEncode {
51+ pub fn new ( ) -> Self {
52+ Self {
53+ signature : Signature :: user_defined ( Volatility :: Immutable ) ,
54+ aliases : vec ! [ ] ,
55+ }
56+ }
57+ }
58+
59+ impl ScalarUDFImpl for SparkUrlEncode {
60+ fn as_any ( & self ) -> & dyn Any {
61+ self
62+ }
63+
64+ fn name ( & self ) -> & str {
65+ "url_encode"
66+ }
67+
68+ fn signature ( & self ) -> & Signature {
69+ & self . signature
70+ }
71+
72+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
73+ Ok ( match & arg_types[ 0 ] {
74+ DataType :: Dictionary ( key_type, _) => {
75+ DataType :: Dictionary ( key_type. clone ( ) , Box :: new ( DataType :: Utf8 ) )
76+ }
77+ _ => DataType :: Utf8 ,
78+ } )
79+ }
80+
81+ fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
82+ if args. args . len ( ) != 1 {
83+ return Err ( internal_datafusion_err ! (
84+ "url_encode expects exactly one argument, got {}" ,
85+ args. args. len( )
86+ ) ) ;
87+ }
88+ let args: [ ColumnarValue ; 1 ] = args
89+ . args
90+ . try_into ( )
91+ . map_err ( |_| internal_datafusion_err ! ( "url_encode expects exactly one argument" ) ) ?;
92+ spark_url_encode ( & args)
93+ }
94+
95+ fn aliases ( & self ) -> & [ String ] {
96+ & self . aliases
97+ }
98+ }
99+
100+ pub fn spark_url_encode ( args : & [ ColumnarValue ; 1 ] ) -> Result < ColumnarValue > {
101+ match args {
102+ [ ColumnarValue :: Array ( array) ] => {
103+ let result = url_encode_array ( array. as_ref ( ) ) ?;
104+ Ok ( ColumnarValue :: Array ( result) )
105+ }
106+ [ ColumnarValue :: Scalar ( scalar) ] => {
107+ let result = url_encode_scalar ( scalar) ?;
108+ Ok ( ColumnarValue :: Scalar ( result) )
109+ }
110+ }
111+ }
112+
113+ fn url_encode_array ( input : & dyn Array ) -> Result < ArrayRef > {
114+ match input. data_type ( ) {
115+ DataType :: Utf8 => {
116+ let array = input. as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
117+ Ok ( url_encode_string_array :: < i32 > ( array) )
118+ }
119+ DataType :: LargeUtf8 => {
120+ let array = input
121+ . as_any ( )
122+ . downcast_ref :: < GenericStringArray < i64 > > ( )
123+ . unwrap ( ) ;
124+ Ok ( url_encode_string_array :: < i64 > ( array) )
125+ }
126+ DataType :: Dictionary ( _, _) => {
127+ let dict = as_dictionary_array :: < Int32Type > ( input) ;
128+ let values = url_encode_array ( dict. values ( ) ) ?;
129+ let result = DictionaryArray :: try_new ( dict. keys ( ) . clone ( ) , values) ?;
130+ Ok ( Arc :: new ( result) )
131+ }
132+ other => exec_err ! ( "Unsupported input type for function 'url_encode': {other:?}" ) ,
133+ }
134+ }
135+
136+ fn url_encode_scalar ( scalar : & ScalarValue ) -> Result < ScalarValue > {
137+ match scalar {
138+ ScalarValue :: Utf8 ( value) | ScalarValue :: LargeUtf8 ( value) => {
139+ let result = value. as_ref ( ) . map ( |s| url_encode_string ( s) ) ;
140+ Ok ( ScalarValue :: Utf8 ( result) )
141+ }
142+ ScalarValue :: Null => Ok ( ScalarValue :: Utf8 ( None ) ) ,
143+ other => exec_err ! ( "Unsupported data type {other:?} for function `url_encode`" ) ,
144+ }
145+ }
146+
147+ fn url_encode_string_array < OffsetSize : OffsetSizeTrait > (
148+ input : & GenericStringArray < OffsetSize > ,
149+ ) -> ArrayRef {
150+ let array_len = input. len ( ) ;
151+ let mut offsets = MutableBuffer :: new ( ( array_len + 1 ) * std:: mem:: size_of :: < OffsetSize > ( ) ) ;
152+ let mut values = MutableBuffer :: new ( input. values ( ) . len ( ) ) ; // reasonable initial capacity
153+ let mut offset_so_far = OffsetSize :: zero ( ) ;
154+ let null_bit_buffer = input. to_data ( ) . nulls ( ) . map ( |b| b. buffer ( ) . clone ( ) ) ;
155+
156+ offsets. push ( offset_so_far) ;
157+
158+ for i in 0 ..array_len {
159+ if !input. is_null ( i) {
160+ let encoded = url_encode_string ( input. value ( i) ) ;
161+ offset_so_far += OffsetSize :: from_usize ( encoded. len ( ) ) . unwrap ( ) ;
162+ values. extend_from_slice ( encoded. as_bytes ( ) ) ;
163+ }
164+ offsets. push ( offset_so_far) ;
165+ }
166+
167+ let data = unsafe {
168+ ArrayData :: new_unchecked (
169+ GenericStringArray :: < OffsetSize > :: DATA_TYPE ,
170+ array_len,
171+ None ,
172+ null_bit_buffer,
173+ 0 ,
174+ vec ! [ offsets. into( ) , values. into( ) ] ,
175+ vec ! [ ] ,
176+ )
177+ } ;
178+ make_array ( data)
179+ }
180+
181+ fn url_encode_length ( s : & str ) -> usize {
182+ let mut len = 0 ;
183+ for byte in s. bytes ( ) {
184+ if should_encode ( byte) {
185+ if byte == b' ' {
186+ len += 1 ; // space -> '+'
187+ } else {
188+ len += 3 ; // other -> %XX
189+ }
190+ } else {
191+ len += 1 ;
192+ }
193+ }
194+ len
195+ }
196+
197+ fn url_encode_string ( s : & str ) -> String {
198+ let mut buf = Vec :: with_capacity ( url_encode_length ( s) ) ;
199+ for byte in s. bytes ( ) {
200+ if !should_encode ( byte) {
201+ buf. push ( byte) ;
202+ } else if byte == b' ' {
203+ buf. push ( b'+' ) ;
204+ } else {
205+ buf. push ( b'%' ) ;
206+ buf. push ( HEX_BYTES [ ( byte >> 4 ) as usize ] ) ;
207+ buf. push ( HEX_BYTES [ ( byte & 0x0F ) as usize ] ) ;
208+ }
209+ }
210+
211+ unsafe { String :: from_utf8_unchecked ( buf) }
212+ }
213+
214+ const HEX_BYTES : [ u8 ; 16 ] = * b"0123456789ABCDEF" ;
215+
/// Reports whether `byte` must be rewritten by the encoder.
///
/// Returns `false` only for the bytes that are emitted unchanged: ASCII letters
/// and digits plus `.`, `-`, `*` and `_`. Every other byte, including space and
/// `~`, returns `true`.
fn should_encode(byte: u8) -> bool {
    let passes_through =
        byte.is_ascii_alphanumeric() || matches!(byte, b'.' | b'-' | b'*' | b'_');
    !passes_through
}
228+
229+
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::common::cast::as_string_array;

    #[test]
    fn test_url_encode_basic() {
        assert_eq!(url_encode_string("Hello World"), "Hello+World");
        assert_eq!(url_encode_string("foo=bar"), "foo%3Dbar");
        assert_eq!(url_encode_string("a+b"), "a%2Bb");
        assert_eq!(url_encode_string(""), "");
    }

    #[test]
    fn test_url_encode_special_chars() {
        assert_eq!(url_encode_string("?"), "%3F");
        assert_eq!(url_encode_string("&"), "%26");
        assert_eq!(url_encode_string("="), "%3D");
        assert_eq!(url_encode_string("#"), "%23");
        assert_eq!(url_encode_string("/"), "%2F");
        assert_eq!(url_encode_string("%"), "%25");
    }

    #[test]
    fn test_url_encode_unreserved_chars() {
        // These must be emitted unchanged.
        assert_eq!(url_encode_string("abc123"), "abc123");
        assert_eq!(url_encode_string("ABC"), "ABC");
        assert_eq!(url_encode_string("."), ".");
        assert_eq!(url_encode_string("-"), "-");
        assert_eq!(url_encode_string("*"), "*");
        assert_eq!(url_encode_string("_"), "_");
    }

    #[test]
    fn test_url_encode_unicode() {
        // Multi-byte UTF-8 characters are percent-encoded byte by byte. The
        // inputs must not carry a trailing space, which would encode to '+'.
        assert_eq!(url_encode_string("cafe\u{0301}"), "cafe%CC%81");
        assert_eq!(url_encode_string("\u{00e9}"), "%C3%A9"); // é as a single char
    }

    #[test]
    fn test_url_encode_array() {
        let input = StringArray::from(vec![
            Some("Hello World"),
            Some("foo=bar"),
            None,
            Some(""),
        ]);
        let args = ColumnarValue::Array(Arc::new(input));
        match spark_url_encode(&[args]) {
            Ok(ColumnarValue::Array(result)) => {
                let actual = as_string_array(&result).unwrap();
                assert_eq!(actual.value(0), "Hello+World");
                assert_eq!(actual.value(1), "foo%3Dbar");
                assert!(actual.is_null(2));
                assert_eq!(actual.value(3), "");
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_url_encode_scalar() {
        let scalar = ScalarValue::Utf8(Some("Hello World".to_string()));
        let result = url_encode_scalar(&scalar).unwrap();
        assert_eq!(result, ScalarValue::Utf8(Some("Hello+World".to_string())));

        let null_scalar = ScalarValue::Utf8(None);
        let null_result = url_encode_scalar(&null_scalar).unwrap();
        assert_eq!(null_result, ScalarValue::Utf8(None));
    }

    #[test]
    fn test_url_encode_tilde() {
        // '~' is unreserved in RFC 3986 but Java's URLEncoder encodes it.
        assert_eq!(url_encode_string("~"), "%7E");
    }
}
0 commit comments