Skip to content

Commit 6ae689c

Browse files
committed
Re-use Html escaping code to implement JSON escaping
1 parent 84edf1c commit 6ae689c

File tree

4 files changed

+154
-45
lines changed

4 files changed

+154
-45
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ members = [
66
"testing",
77
"testing-alloc",
88
"testing-no-std",
9-
"testing-renamed"
9+
"testing-renamed",
1010
]
1111
resolver = "2"

rinja/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ blocks = ["rinja_derive?/blocks"]
5757
code-in-doc = ["rinja_derive?/code-in-doc"]
5858
config = ["rinja_derive?/config"]
5959
derive = ["rinja_derive"]
60-
serde_json = ["rinja_derive?/serde_json", "dep:serde", "dep:serde_json"]
60+
serde_json = ["std", "rinja_derive?/serde_json", "dep:serde", "dep:serde_json"]
6161
std = [
6262
"alloc",
6363
"rinja_derive?/std",

rinja/src/ascii_str.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,19 @@ impl AsciiChar {
114114
Self::new(ALPHABET[d as usize % ALPHABET.len()]),
115115
]
116116
}
117+
118+
#[inline]
119+
pub const fn two_hex_digits(d: u32) -> [Self; 2] {
120+
const ALPHABET: &[u8; 16] = b"0123456789abcdef";
121+
122+
if d >= ALPHABET.len().pow(2) as u32 {
123+
panic!();
124+
}
125+
[
126+
Self::new(ALPHABET[d as usize / ALPHABET.len()]),
127+
Self::new(ALPHABET[d as usize % ALPHABET.len()]),
128+
]
129+
}
117130
}
118131

119132
mod _ascii_char {

rinja/src/filters/json.rs

Lines changed: 139 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ use std::pin::Pin;
44
use std::{fmt, io, str};
55

66
use serde::Serialize;
7-
use serde_json::ser::{PrettyFormatter, Serializer, to_writer};
7+
use serde_json::ser::{CompactFormatter, PrettyFormatter, Serializer};
88

99
use super::FastWritable;
10+
use crate::ascii_str::{AsciiChar, AsciiStr};
1011

1112
/// Serialize to JSON (requires `json` feature)
1213
///
@@ -187,9 +188,8 @@ where
187188
}
188189

189190
impl<S: Serialize> FastWritable for ToJson<S> {
190-
#[inline]
191191
fn write_into<W: fmt::Write + ?Sized>(&self, f: &mut W) -> crate::Result<()> {
192-
fmt_json(f, &self.value)
192+
serialize(f, &self.value, CompactFormatter)
193193
}
194194
}
195195

@@ -201,9 +201,12 @@ impl<S: Serialize> fmt::Display for ToJson<S> {
201201
}
202202

203203
impl<S: Serialize, I: AsIndent> FastWritable for ToJsonPretty<S, I> {
204-
#[inline]
205204
fn write_into<W: fmt::Write + ?Sized>(&self, f: &mut W) -> crate::Result<()> {
206-
fmt_json_pretty(f, &self.value, self.indent.as_indent())
205+
serialize(
206+
f,
207+
&self.value,
208+
PrettyFormatter::with_indent(self.indent.as_indent().as_bytes()),
209+
)
207210
}
208211
}
209212

@@ -214,58 +217,151 @@ impl<S: Serialize, I: AsIndent> fmt::Display for ToJsonPretty<S, I> {
214217
}
215218
}
216219

217-
fn fmt_json<S: Serialize, W: fmt::Write + ?Sized>(dest: &mut W, value: &S) -> crate::Result<()> {
218-
Ok(to_writer(JsonWriter(dest), value)?)
219-
}
220+
#[inline]
221+
fn serialize<S, W, F>(dest: &mut W, value: &S, formatter: F) -> Result<(), crate::Error>
222+
where
223+
S: Serialize + ?Sized,
224+
W: fmt::Write + ?Sized,
225+
F: serde_json::ser::Formatter,
226+
{
227+
/// The struct must only ever be used with the output of `serde_json`.
228+
/// `serde_json` only produces UTF-8 strings in its `io::Write::write()` calls,
229+
/// and `<JsonWriter as io::Write>` depends on this invariant.
230+
struct JsonWriter<'a, W: fmt::Write + ?Sized>(&'a mut W);
220231

221-
fn fmt_json_pretty<S: Serialize, W: fmt::Write + ?Sized>(
222-
dest: &mut W,
223-
value: &S,
224-
indent: &str,
225-
) -> crate::Result<()> {
226-
let formatter = PrettyFormatter::with_indent(indent.as_bytes());
227-
let mut serializer = Serializer::with_formatter(JsonWriter(dest), formatter);
228-
Ok(value.serialize(&mut serializer)?)
229-
}
232+
impl<W: fmt::Write + ?Sized> io::Write for JsonWriter<'_, W> {
233+
/// Invariant: must be passed valid UTF-8 slices
234+
#[inline]
235+
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
236+
self.write_all(bytes)?;
237+
Ok(bytes.len())
238+
}
230239

231-
struct JsonWriter<'a, W: fmt::Write + ?Sized>(&'a mut W);
240+
/// Invariant: must be passed valid UTF-8 slices
241+
fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
242+
// SAFETY: `serde_json` only writes valid strings
243+
let string = unsafe { std::str::from_utf8_unchecked(bytes) };
244+
write_escaped_str(&mut *self.0, string)
245+
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
246+
}
232247

233-
impl<W: fmt::Write + ?Sized> io::Write for JsonWriter<'_, W> {
234-
#[inline]
235-
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
236-
self.write_all(bytes)?;
237-
Ok(bytes.len())
248+
#[inline]
249+
fn flush(&mut self) -> io::Result<()> {
250+
Ok(())
251+
}
238252
}
239253

254+
/// Invariant: no character that needs escaping is multi-byte character when encoded in UTF-8;
255+
/// that is true for characters in ASCII range.
240256
#[inline]
241-
fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
242-
write(self.0, bytes).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
257+
fn write_escaped_str(dest: &mut (impl fmt::Write + ?Sized), src: &str) -> fmt::Result {
258+
// This implementation reads one byte after another.
259+
// It's not very fast, but should work well enough until portable SIMD gets stabilized.
260+
261+
let mut escaped_buf = ESCAPED_BUF_INIT;
262+
let mut last = 0;
263+
264+
for (index, byte) in src.bytes().enumerate() {
265+
if let Some(escaped) = get_escaped(byte) {
266+
[escaped_buf[4], escaped_buf[5]] = escaped;
267+
write_str_if_nonempty(dest, &src[last..index])?;
268+
dest.write_str(AsciiStr::from_slice(&escaped_buf[..ESCAPED_BUF_LEN]))?;
269+
last = index + 1;
270+
}
271+
}
272+
write_str_if_nonempty(dest, &src[last..])
243273
}
244274

245-
#[inline]
246-
fn flush(&mut self) -> io::Result<()> {
275+
let mut serializer = Serializer::with_formatter(JsonWriter(dest), formatter);
276+
Ok(value.serialize(&mut serializer)?)
277+
}
278+
279+
/// Returns the decimal representation of the codepoint if the character needs HTML escaping.
280+
#[inline]
281+
fn get_escaped(byte: u8) -> Option<[AsciiChar; 2]> {
282+
const _: () = assert!(CHAR_RANGE < 32);
283+
284+
if let MIN_CHAR..=MAX_CHAR = byte {
285+
if (1u32 << (byte - MIN_CHAR)) & BITS != 0 {
286+
return Some(TABLE.0[(byte - MIN_CHAR) as usize]);
287+
}
288+
}
289+
None
290+
}
291+
292+
#[inline(always)]
293+
fn write_str_if_nonempty(output: &mut (impl fmt::Write + ?Sized), input: &str) -> fmt::Result {
294+
if !input.is_empty() {
295+
output.write_str(input)
296+
} else {
247297
Ok(())
248298
}
249299
}
250300

251-
fn write<W: fmt::Write + ?Sized>(f: &mut W, bytes: &[u8]) -> fmt::Result {
252-
let mut last = 0;
253-
for (index, byte) in bytes.iter().enumerate() {
254-
let escaped = match byte {
255-
b'&' => Some(br"\u0026"),
256-
b'\'' => Some(br"\u0027"),
257-
b'<' => Some(br"\u003c"),
258-
b'>' => Some(br"\u003e"),
259-
_ => None,
260-
};
261-
if let Some(escaped) = escaped {
262-
f.write_str(unsafe { str::from_utf8_unchecked(&bytes[last..index]) })?;
263-
f.write_str(unsafe { str::from_utf8_unchecked(escaped) })?;
264-
last = index + 1;
301+
/// List of characters that need HTML escaping, not necessarily in ordinal order.
302+
const CHARS: &[u8] = br#"&'<>"#;
303+
304+
/// The character with the lowest codepoint that needs HTML escaping.
305+
const MIN_CHAR: u8 = {
306+
let mut v = u8::MAX;
307+
let mut i = 0;
308+
while i < CHARS.len() {
309+
if v > CHARS[i] {
310+
v = CHARS[i];
265311
}
312+
i += 1;
266313
}
267-
f.write_str(unsafe { str::from_utf8_unchecked(&bytes[last..]) })
268-
}
314+
v
315+
};
316+
317+
/// The character with the highest codepoint that needs HTML escaping.
318+
const MAX_CHAR: u8 = {
319+
let mut v = u8::MIN;
320+
let mut i = 0;
321+
while i < CHARS.len() {
322+
if v < CHARS[i] {
323+
v = CHARS[i];
324+
}
325+
i += 1;
326+
}
327+
v
328+
};
329+
330+
const BITS: u32 = {
331+
let mut bits = 0;
332+
let mut i = 0;
333+
while i < CHARS.len() {
334+
bits |= 1 << (CHARS[i] - MIN_CHAR);
335+
i += 1;
336+
}
337+
bits
338+
};
339+
340+
/// Number of codepoints between the lowest and highest character that needs escaping, incl.
341+
const CHAR_RANGE: usize = (MAX_CHAR - MIN_CHAR + 1) as usize;
342+
343+
#[repr(align(64))]
344+
struct Table([[AsciiChar; 2]; CHAR_RANGE]);
345+
346+
/// For characters that need HTML escaping, the codepoint is formatted as decimal digits,
347+
/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`].
348+
const TABLE: &Table = &{
349+
let mut table = Table([UNESCAPED; CHAR_RANGE]);
350+
let mut i = 0;
351+
while i < CHARS.len() {
352+
let c = CHARS[i];
353+
table.0[c as u32 as usize - MIN_CHAR as usize] = AsciiChar::two_hex_digits(c as u32);
354+
i += 1;
355+
}
356+
table
357+
};
358+
359+
const UNESCAPED: [AsciiChar; 2] = AsciiStr::new_sized("");
360+
361+
const ESCAPED_BUF_INIT_UNPADDED: &str = "\\u00__";
362+
// RATIONALE: llvm generates better code if the buffer is register sized
363+
const ESCAPED_BUF_INIT: [AsciiChar; 8] = AsciiStr::new_sized(ESCAPED_BUF_INIT_UNPADDED);
364+
const ESCAPED_BUF_LEN: usize = ESCAPED_BUF_INIT_UNPADDED.len();
269365

270366
#[cfg(all(test, feature = "alloc"))]
271367
mod tests {

0 commit comments

Comments
 (0)