From b0a8c508f47733893ad64307d516aaab94b0c774 Mon Sep 17 00:00:00 2001 From: LongYinan Date: Wed, 14 May 2025 18:01:19 +0800 Subject: [PATCH] Prefetch cache line ahead in loop --- examples/escape.rs | 6 +- rust-toolchain.toml | 3 + src/aarch64.rs | 147 +++++++++++++++++++++++++++++--------------- 3 files changed, 106 insertions(+), 50 deletions(-) create mode 100644 rust-toolchain.toml diff --git a/examples/escape.rs b/examples/escape.rs index 93d6fe4..6343880 100644 --- a/examples/escape.rs +++ b/examples/escape.rs @@ -1,6 +1,8 @@ -use string_escape_simd::encode_str; +use string_escape_simd::{encode_str, encode_str_fallback}; fn main() { let fixture = include_str!("../cal.com.tsx"); - encode_str(fixture); + let encoded = encode_str(fixture); + let encoded_fallback = encode_str_fallback(fixture); + assert_eq!(encoded, encoded_fallback); } diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..7f43b88 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly-2025-05-14" +profile = "default" diff --git a/src/aarch64.rs b/src/aarch64.rs index c6c5ba2..5caf773 100644 --- a/src/aarch64.rs +++ b/src/aarch64.rs @@ -1,65 +1,116 @@ use std::arch::aarch64::{ - uint8x16_t, // lane type - vceqq_u8, - vdupq_n_u8, // comparisons / splat - vld1q_u8, - vld1q_u8_x4, // loads - vmaxvq_u8, // horizontal-max reduction - vorrq_u8, // bit-wise OR - vqtbl4q_u8, // table-lookup + vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8, vst1q_u8, }; -use std::mem::transmute; use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS}; -const CHUNK_SIZE: usize = 16; +/// Four contiguous 16-byte NEON registers (64 B) per loop. +const CHUNK: usize = 64; pub fn encode_str>(input: S) -> String { - let input_str = input.as_ref(); - let mut output = Vec::with_capacity(input_str.len() + 2); - let bytes = input_str.as_bytes(); - let len = bytes.len(); - let writer = &mut output; - writer.push(b'"'); - // Safety: SIMD instructions + let s = input.as_ref(); + let mut out = Vec::with_capacity(s.len() + 2); + let b = s.as_bytes(); + let n = b.len(); + out.push(b'"'); + unsafe { - let mut start = 0; - let table_low = vld1q_u8_x4(ESCAPE[0..64].as_ptr()); - let table_high = vdupq_n_u8(b'\\'); - while start + CHUNK_SIZE < len { - let current = &bytes[start..start + CHUNK_SIZE]; - - let chunk = vld1q_u8(current.as_ptr()); - let low_mask = vqtbl4q_u8(table_low, chunk); - let high_mask = vceqq_u8(table_high, chunk); - if vmaxvq_u8(low_mask) == 0 && vmaxvq_u8(high_mask) == 0 { - writer.extend_from_slice(current); - start += CHUNK_SIZE; + let tbl = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of the escape table + let slash = vdupq_n_u8(b'\\'); + let mut i = 0; + + while i + CHUNK <= n { + let ptr = b.as_ptr().add(i); + + /* ---- L1 prefetch: one cache line ahead ---- */ + core::arch::asm!("prfm pldl1keep, [{0}, #128]", in(reg) ptr); + /* ------------------------------------------ */ + + // load 64 B (four q-regs) + let a = vld1q_u8(ptr); + let m1 = vqtbl4q_u8(tbl, a); + let m2 = vceqq_u8(slash, a); + + let b2 = vld1q_u8(ptr.add(16)); + let m3 = vqtbl4q_u8(tbl, b2); + let m4 = vceqq_u8(slash, b2); + + let c = vld1q_u8(ptr.add(32)); + let m5 = vqtbl4q_u8(tbl, c); + let m6 = vceqq_u8(slash, c); + + let d = vld1q_u8(ptr.add(48)); + let m7 = vqtbl4q_u8(tbl, d); + let m8 = vceqq_u8(slash, d); + + let mask_1 = vorrq_u8(m1, m2); + let mask_2 = vorrq_u8(m3, m4); + let mask_3 = vorrq_u8(m5, m6); + let mask_4 = vorrq_u8(m7, m8); + + let mask_r_1 = vmaxvq_u8(mask_1); + let mask_r_2 = vmaxvq_u8(mask_2); + let mask_r_3 = vmaxvq_u8(mask_3); + let mask_r_4 = vmaxvq_u8(mask_4); + + // fast path: nothing needs escaping + if mask_r_1 | mask_r_2 | mask_r_3 | mask_r_4 == 0 { + out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK)); + i += CHUNK; continue; } + let mut tmp: [u8; 16] = core::mem::zeroed(); - // Vector add the masks to get a single mask - let escape_mask = vorrq_u8(low_mask, high_mask); - let escape_table_mask_slice = transmute::(escape_mask); - for (index, value) in escape_table_mask_slice.into_iter().enumerate() { - if value == 0 { - writer.push(bytes[start + index]); - } else if value == 255 { - // value is in the high table mask, which means it's `\` - writer.extend_from_slice(REVERSE_SOLIDUS); - } else { - let char_escape = CharEscape::from_escape_table(value, current[index]); - write_char_escape(writer, char_escape); - } + if mask_r_1 == 0 { + out.extend_from_slice(std::slice::from_raw_parts(ptr, 16)); + } else { + vst1q_u8(tmp.as_mut_ptr(), mask_1); + handle_block(&b[i..i + 16], &tmp, &mut out); } - start += CHUNK_SIZE; + + if mask_r_2 == 0 { + out.extend_from_slice(std::slice::from_raw_parts(ptr.add(16), 16)); + } else { + vst1q_u8(tmp.as_mut_ptr(), mask_2); + handle_block(&b[i + 16..i + 32], &tmp, &mut out); + } + + if mask_r_3 == 0 { + out.extend_from_slice(std::slice::from_raw_parts(ptr.add(32), 16)); + } else { + vst1q_u8(tmp.as_mut_ptr(), mask_3); + handle_block(&b[i + 32..i + 48], &tmp, &mut out); + } + + if mask_r_4 == 0 { + out.extend_from_slice(std::slice::from_raw_parts(ptr.add(48), 16)); + } else { + vst1q_u8(tmp.as_mut_ptr(), mask_4); + handle_block(&b[i + 48..i + 64], &tmp, &mut out); + } + + i += CHUNK; + } + if i < n { + encode_str_inner(&b[i..], &mut out); } + } + out.push(b'"'); + // SAFETY: we only emit valid UTF-8 + unsafe { String::from_utf8_unchecked(out) } +} - if start < len { - encode_str_inner(&bytes[start..], writer); +#[inline(always)] +unsafe fn handle_block(src: &[u8], mask: &[u8; 16], dst: &mut Vec) { + for (j, &m) in mask.iter().enumerate() { + let c = src[j]; + if m == 0 { + dst.push(c); + } else if m == 0xFF { + dst.extend_from_slice(REVERSE_SOLIDUS); + } else { + let e = CharEscape::from_escape_table(m, c); + write_char_escape(dst, e); } } - writer.push(b'"'); - // Safety: the bytes are valid UTF-8 - unsafe { String::from_utf8_unchecked(output) } }