Skip to content

Prefetch cache line ahead in loop #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/escape.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use string_escape_simd::encode_str;
use string_escape_simd::{encode_str, encode_str_fallback};

fn main() {
let fixture = include_str!("../cal.com.tsx");
encode_str(fixture);
let encoded = encode_str(fixture);
let encoded_fallback = encode_str_fallback(fixture);
assert_eq!(encoded, encoded_fallback);
}
3 changes: 3 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[toolchain]
channel = "nightly-2025-05-14"
profile = "default"
147 changes: 99 additions & 48 deletions src/aarch64.rs
Original file line number Diff line number Diff line change
@@ -1,65 +1,116 @@
use std::arch::aarch64::{
uint8x16_t, // lane type
vceqq_u8,
vdupq_n_u8, // comparisons / splat
vld1q_u8,
vld1q_u8_x4, // loads
vmaxvq_u8, // horizontal-max reduction
vorrq_u8, // bit-wise OR
vqtbl4q_u8, // table-lookup
vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8, vst1q_u8,
};
use std::mem::transmute;

use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS};

const CHUNK_SIZE: usize = 16;
/// Four contiguous 16-byte NEON registers (64 B) per loop.
const CHUNK: usize = 64;

pub fn encode_str<S: AsRef<str>>(input: S) -> String {
let input_str = input.as_ref();
let mut output = Vec::with_capacity(input_str.len() + 2);
let bytes = input_str.as_bytes();
let len = bytes.len();
let writer = &mut output;
writer.push(b'"');
// Safety: SIMD instructions
let s = input.as_ref();
let mut out = Vec::with_capacity(s.len() + 2);
let b = s.as_bytes();
let n = b.len();
out.push(b'"');

unsafe {
let mut start = 0;
let table_low = vld1q_u8_x4(ESCAPE[0..64].as_ptr());
let table_high = vdupq_n_u8(b'\\');
while start + CHUNK_SIZE < len {
let current = &bytes[start..start + CHUNK_SIZE];

let chunk = vld1q_u8(current.as_ptr());
let low_mask = vqtbl4q_u8(table_low, chunk);
let high_mask = vceqq_u8(table_high, chunk);
if vmaxvq_u8(low_mask) == 0 && vmaxvq_u8(high_mask) == 0 {
writer.extend_from_slice(current);
start += CHUNK_SIZE;
let tbl = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of the escape table
let slash = vdupq_n_u8(b'\\');
let mut i = 0;

while i + CHUNK <= n {
let ptr = b.as_ptr().add(i);

/* ---- L1 prefetch: one cache line ahead ---- */
core::arch::asm!("prfm pldl1keep, [{0}, #128]", in(reg) ptr);
/* ------------------------------------------ */

// load 64 B (four q-regs)
let a = vld1q_u8(ptr);
let m1 = vqtbl4q_u8(tbl, a);
let m2 = vceqq_u8(slash, a);

let b2 = vld1q_u8(ptr.add(16));
let m3 = vqtbl4q_u8(tbl, b2);
let m4 = vceqq_u8(slash, b2);

let c = vld1q_u8(ptr.add(32));
let m5 = vqtbl4q_u8(tbl, c);
let m6 = vceqq_u8(slash, c);

let d = vld1q_u8(ptr.add(48));
let m7 = vqtbl4q_u8(tbl, d);
let m8 = vceqq_u8(slash, d);

let mask_1 = vorrq_u8(m1, m2);
let mask_2 = vorrq_u8(m3, m4);
let mask_3 = vorrq_u8(m5, m6);
let mask_4 = vorrq_u8(m7, m8);

let mask_r_1 = vmaxvq_u8(mask_1);
let mask_r_2 = vmaxvq_u8(mask_2);
let mask_r_3 = vmaxvq_u8(mask_3);
let mask_r_4 = vmaxvq_u8(mask_4);

// fast path: nothing needs escaping
if mask_r_1 | mask_r_2 | mask_r_3 | mask_r_4 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
i += CHUNK;
continue;
}
let mut tmp: [u8; 16] = core::mem::zeroed();

// Vector add the masks to get a single mask
let escape_mask = vorrq_u8(low_mask, high_mask);
let escape_table_mask_slice = transmute::<uint8x16_t, [u8; 16]>(escape_mask);
for (index, value) in escape_table_mask_slice.into_iter().enumerate() {
if value == 0 {
writer.push(bytes[start + index]);
} else if value == 255 {
// value is in the high table mask, which means it's `\`
writer.extend_from_slice(REVERSE_SOLIDUS);
} else {
let char_escape = CharEscape::from_escape_table(value, current[index]);
write_char_escape(writer, char_escape);
}
if mask_r_1 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr, 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_1);
handle_block(&b[i..i + 16], &tmp, &mut out);
}
start += CHUNK_SIZE;

if mask_r_2 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(16), 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_2);
handle_block(&b[i + 16..i + 32], &tmp, &mut out);
}

if mask_r_3 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(32), 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_3);
handle_block(&b[i + 32..i + 48], &tmp, &mut out);
}

if mask_r_4 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(48), 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_4);
handle_block(&b[i + 48..i + 64], &tmp, &mut out);
}

i += CHUNK;
}
if i < n {
encode_str_inner(&b[i..], &mut out);
}
}
out.push(b'"');
// SAFETY: we only emit valid UTF-8
unsafe { String::from_utf8_unchecked(out) }
}

if start < len {
encode_str_inner(&bytes[start..], writer);
#[inline(always)]
unsafe fn handle_block(src: &[u8], mask: &[u8; 16], dst: &mut Vec<u8>) {
for (j, &m) in mask.iter().enumerate() {
let c = src[j];
if m == 0 {
dst.push(c);
} else if m == 0xFF {
dst.extend_from_slice(REVERSE_SOLIDUS);
} else {
let e = CharEscape::from_escape_table(m, c);
write_char_escape(dst, e);
}
}
writer.push(b'"');
// Safety: the bytes are valid UTF-8
unsafe { String::from_utf8_unchecked(output) }
}