Skip to content

Commit d79048c

Browse files
authored
Prefetch cache line ahead in loop (#3)
1 parent 627e3f5 commit d79048c

File tree

3 files changed

+106
-50
lines changed

3 files changed

+106
-50
lines changed

examples/escape.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
use string_escape_simd::encode_str;
1+
use string_escape_simd::{encode_str, encode_str_fallback};
22

33
fn main() {
44
let fixture = include_str!("../cal.com.tsx");
5-
encode_str(fixture);
5+
let encoded = encode_str(fixture);
6+
let encoded_fallback = encode_str_fallback(fixture);
7+
assert_eq!(encoded, encoded_fallback);
68
}

rust-toolchain.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[toolchain]
2+
channel = "nightly-2025-05-14"
3+
profile = "default"

src/aarch64.rs

Lines changed: 99 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,116 @@
11
use std::arch::aarch64::{
2-
uint8x16_t, // lane type
3-
vceqq_u8,
4-
vdupq_n_u8, // comparisons / splat
5-
vld1q_u8,
6-
vld1q_u8_x4, // loads
7-
vmaxvq_u8, // horizontal-max reduction
8-
vorrq_u8, // bit-wise OR
9-
vqtbl4q_u8, // table-lookup
2+
vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8, vst1q_u8,
103
};
11-
use std::mem::transmute;
124

135
use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS};
146

15-
const CHUNK_SIZE: usize = 16;
7+
/// Bytes consumed per loop iteration: four 16-byte NEON registers (64 B) loaded from contiguous memory.
8+
const CHUNK: usize = 64;
169

1710
pub fn encode_str<S: AsRef<str>>(input: S) -> String {
18-
let input_str = input.as_ref();
19-
let mut output = Vec::with_capacity(input_str.len() + 2);
20-
let bytes = input_str.as_bytes();
21-
let len = bytes.len();
22-
let writer = &mut output;
23-
writer.push(b'"');
24-
// Safety: SIMD instructions
11+
let s = input.as_ref();
12+
let mut out = Vec::with_capacity(s.len() + 2);
13+
let b = s.as_bytes();
14+
let n = b.len();
15+
out.push(b'"');
16+
2517
unsafe {
26-
let mut start = 0;
27-
let table_low = vld1q_u8_x4(ESCAPE[0..64].as_ptr());
28-
let table_high = vdupq_n_u8(b'\\');
29-
while start + CHUNK_SIZE < len {
30-
let current = &bytes[start..start + CHUNK_SIZE];
31-
32-
let chunk = vld1q_u8(current.as_ptr());
33-
let low_mask = vqtbl4q_u8(table_low, chunk);
34-
let high_mask = vceqq_u8(table_high, chunk);
35-
if vmaxvq_u8(low_mask) == 0 && vmaxvq_u8(high_mask) == 0 {
36-
writer.extend_from_slice(current);
37-
start += CHUNK_SIZE;
18+
let tbl = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of the escape table
19+
let slash = vdupq_n_u8(b'\\');
20+
let mut i = 0;
21+
22+
while i + CHUNK <= n {
23+
let ptr = b.as_ptr().add(i);
24+
25+
/* ---- L1 prefetch: 128 B (two 64 B chunks) past the current pointer ---- */
26+
core::arch::asm!("prfm pldl1keep, [{0}, #128]", in(reg) ptr);
27+
/* ------------------------------------------ */
28+
29+
// load 64 B (four q-regs)
30+
let a = vld1q_u8(ptr);
31+
let m1 = vqtbl4q_u8(tbl, a);
32+
let m2 = vceqq_u8(slash, a);
33+
34+
let b2 = vld1q_u8(ptr.add(16));
35+
let m3 = vqtbl4q_u8(tbl, b2);
36+
let m4 = vceqq_u8(slash, b2);
37+
38+
let c = vld1q_u8(ptr.add(32));
39+
let m5 = vqtbl4q_u8(tbl, c);
40+
let m6 = vceqq_u8(slash, c);
41+
42+
let d = vld1q_u8(ptr.add(48));
43+
let m7 = vqtbl4q_u8(tbl, d);
44+
let m8 = vceqq_u8(slash, d);
45+
46+
let mask_1 = vorrq_u8(m1, m2);
47+
let mask_2 = vorrq_u8(m3, m4);
48+
let mask_3 = vorrq_u8(m5, m6);
49+
let mask_4 = vorrq_u8(m7, m8);
50+
51+
let mask_r_1 = vmaxvq_u8(mask_1);
52+
let mask_r_2 = vmaxvq_u8(mask_2);
53+
let mask_r_3 = vmaxvq_u8(mask_3);
54+
let mask_r_4 = vmaxvq_u8(mask_4);
55+
56+
// fast path: nothing needs escaping
57+
if mask_r_1 | mask_r_2 | mask_r_3 | mask_r_4 == 0 {
58+
out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
59+
i += CHUNK;
3860
continue;
3961
}
62+
let mut tmp: [u8; 16] = core::mem::zeroed();
4063

41-
// Vector add the masks to get a single mask
42-
let escape_mask = vorrq_u8(low_mask, high_mask);
43-
let escape_table_mask_slice = transmute::<uint8x16_t, [u8; 16]>(escape_mask);
44-
for (index, value) in escape_table_mask_slice.into_iter().enumerate() {
45-
if value == 0 {
46-
writer.push(bytes[start + index]);
47-
} else if value == 255 {
48-
// value is in the high table mask, which means it's `\`
49-
writer.extend_from_slice(REVERSE_SOLIDUS);
50-
} else {
51-
let char_escape = CharEscape::from_escape_table(value, current[index]);
52-
write_char_escape(writer, char_escape);
53-
}
64+
if mask_r_1 == 0 {
65+
out.extend_from_slice(std::slice::from_raw_parts(ptr, 16));
66+
} else {
67+
vst1q_u8(tmp.as_mut_ptr(), mask_1);
68+
handle_block(&b[i..i + 16], &tmp, &mut out);
5469
}
55-
start += CHUNK_SIZE;
70+
71+
if mask_r_2 == 0 {
72+
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(16), 16));
73+
} else {
74+
vst1q_u8(tmp.as_mut_ptr(), mask_2);
75+
handle_block(&b[i + 16..i + 32], &tmp, &mut out);
76+
}
77+
78+
if mask_r_3 == 0 {
79+
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(32), 16));
80+
} else {
81+
vst1q_u8(tmp.as_mut_ptr(), mask_3);
82+
handle_block(&b[i + 32..i + 48], &tmp, &mut out);
83+
}
84+
85+
if mask_r_4 == 0 {
86+
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(48), 16));
87+
} else {
88+
vst1q_u8(tmp.as_mut_ptr(), mask_4);
89+
handle_block(&b[i + 48..i + 64], &tmp, &mut out);
90+
}
91+
92+
i += CHUNK;
93+
}
94+
if i < n {
95+
encode_str_inner(&b[i..], &mut out);
5696
}
97+
}
98+
out.push(b'"');
99+
// SAFETY: we only emit valid UTF-8
100+
unsafe { String::from_utf8_unchecked(out) }
101+
}
57102

58-
if start < len {
59-
encode_str_inner(&bytes[start..], writer);
103+
#[inline(always)]
104+
unsafe fn handle_block(src: &[u8], mask: &[u8; 16], dst: &mut Vec<u8>) {
105+
for (j, &m) in mask.iter().enumerate() {
106+
let c = src[j];
107+
if m == 0 {
108+
dst.push(c);
109+
} else if m == 0xFF {
110+
dst.extend_from_slice(REVERSE_SOLIDUS);
111+
} else {
112+
let e = CharEscape::from_escape_table(m, c);
113+
write_char_escape(dst, e);
60114
}
61115
}
62-
writer.push(b'"');
63-
// Safety: the bytes are valid UTF-8
64-
unsafe { String::from_utf8_unchecked(output) }
65116
}

0 commit comments

Comments
 (0)