Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- use std::arch::x86_64::*;
- use std::fmt;
- use std::mem::size_of;
- use std::str;
/// Number of bytes held by one AVX2 register (`__m256i`).
const VECTOR_SIZE: usize = size_of::<__m256i>();
/// Low-bit mask giving an address's misalignment relative to `VECTOR_SIZE`.
const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
/// Bytes consumed per iteration of the unrolled loop (two vectors at a time).
const LOOP_SIZE: usize = VECTOR_SIZE * 2;
/// Largest offset from `b'"'` of any escaped byte: every byte that may need
/// escaping satisfies `b.wrapping_sub(b'"') <= FLAG`.
const FLAG: u8 = b'>' - b'"';
- pub fn escape(s: &str) -> Escaped {
- Escaped {
- bytes: s.as_bytes(),
- }
- }
/// Handles one escaped byte located at index `$i`.
///
/// It first flushes the pending unescaped run `$_self.bytes[$start..$i]` to
/// `$fmt` (when that run is non-empty). It then writes the replacement text
/// `$quote` and moves `$start` just past the escaped byte.
///
/// Must be expanded in an `unsafe` context, inside a function returning
/// `fmt::Result`: it calls `str::from_utf8_unchecked` and uses `?`.
macro_rules! escaping_body {
    ($start:ident, $i:ident, $fmt:ident, $_self:ident, $quote:expr) => {{
        if $i > $start {
            let pending = &$_self.bytes[$start..$i];
            $fmt.write_str(str::from_utf8_unchecked(pending))?;
        }
        $fmt.write_str($quote)?;
        $start = $i + 1;
    }};
}
/// Writes the HTML escape for byte `$b` (at index `$i`) when it is one of the
/// escaped characters. Any other byte is left in the pending run untouched.
///
/// Replacements:
/// `<` -> `&lt;`, `>` -> `&gt;`, `&` -> `&amp;`, `"` -> `&quot;`,
/// `'` -> `&#x27;`, `/` -> `&#x2f;`.
///
/// All six bytes lie in `b'"'..=b'>'`, which is what the `FLAG` range
/// pre-filter used by the callers relies on.
macro_rules! escaping_bodies {
    ($start:ident, $b:ident, $i:ident, $fmt:ident, $_self:ident) => {{
        match $b {
            b'<' => escaping_body!($start, $i, $fmt, $_self, "&lt;"),
            b'>' => escaping_body!($start, $i, $fmt, $_self, "&gt;"),
            b'&' => escaping_body!($start, $i, $fmt, $_self, "&amp;"),
            b'"' => escaping_body!($start, $i, $fmt, $_self, "&quot;"),
            // Numeric reference: `&apos;` is not defined in HTML 4.
            b'\'' => escaping_body!($start, $i, $fmt, $_self, "&#x27;"),
            b'/' => escaping_body!($start, $i, $fmt, $_self, "&#x2f;"),
            _ => (),
        }
    }};
}
/// Writes escapes for the candidate bytes of the 32-byte chunk at `$ptr`.
///
/// `$cmp` is a comparison vector. `_mm256_movemask_epi8` packs the top bit of
/// each lane into a bitmask whose bit `k` refers to the byte at `$ptr + k`.
/// Candidates that are not actually escaped characters fall through
/// `escaping_bodies!` unchanged.
///
/// Must be expanded in an `unsafe` context, inside a function returning
/// `fmt::Result`, with `$ptr` pointing at 32 readable bytes of the input that
/// starts at `$start_ptr`.
macro_rules! write_mask {
    ($cmp:ident, $ptr:ident, $start_ptr:ident, $start:ident, $fmt:ident, $_self:ident) => {{
        // Reinterpret as unsigned so clearing the lowest set bit cannot overflow.
        let mut mask = _mm256_movemask_epi8($cmp) as u32;
        if mask != 0 {
            // Offset of this chunk from the beginning of the input.
            let base = sub($ptr, $start_ptr);
            while mask != 0 {
                let pos = mask.trailing_zeros() as usize;
                debug_assert!(pos < 32);
                let b = *$ptr.add(pos);
                let at = base + pos;
                escaping_bodies!($start, b, at, $fmt, $_self);
                // Clear the lowest set bit to move on to the next candidate.
                mask &= mask - 1;
            }
        }
    }};
}
/// Lazily HTML-escaped view of a string, created by `escape`.
///
/// Only the original bytes are stored. Escaping happens each time the value is
/// formatted through its `Display` implementation.
#[derive(Debug, Clone, Copy)]
pub struct Escaped<'a> {
    /// Bytes of the `&str` handed to `escape`, hence valid UTF-8.
    bytes: &'a [u8],
}
- impl<'a> fmt::Display for Escaped<'a> {
- fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
- let len = self.bytes.len();
- let mut start = 0;
- unsafe {
- if len < VECTOR_SIZE {
- for (i, b) in self.bytes.iter().enumerate() {
- if b.wrapping_sub(b'"') <= FLAG {
- let b = *b;
- escaping_bodies!(start, b, i, fmt, self);
- }
- }
- fmt.write_str(str::from_utf8_unchecked(&self.bytes[start..]))?;
- return Ok(());
- }
- let v_flag = _mm256_set1_epi8((FLAG + 1) as i8);
- let v_flag_below = _mm256_set1_epi8(b'"' as i8);
- let start_ptr = self.bytes.as_ptr();
- let end_ptr = self.bytes[len..].as_ptr();
- let mut ptr = start_ptr;
- let align = VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN);
- // Align
- for i in 0..align {
- let b = *ptr;
- if b.wrapping_sub(b'"') <= FLAG {
- escaping_bodies!(start, b, i, fmt, self);
- }
- ptr = ptr.add(1);
- }
- debug_assert!(start <= sub(ptr, start_ptr));
- debug_assert!(ptr > start_ptr && end_ptr.sub(VECTOR_SIZE) >= start_ptr);
- // body
- if LOOP_SIZE <= len {
- while ptr <= end_ptr.sub(LOOP_SIZE) {
- debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);
- let a = _mm256_load_si256(ptr as *const __m256i);
- let b = _mm256_load_si256(ptr.add(VECTOR_SIZE) as *const __m256i);
- let cmp_a = _mm256_cmpgt_epi8(v_flag, _mm256_sub_epi8(a, v_flag_below));
- let cmp_b = _mm256_cmpgt_epi8(v_flag, _mm256_sub_epi8(b, v_flag_below));
- let mask = _mm256_movemask_epi8(_mm256_or_si256(cmp_a, cmp_b));
- if mask != 0 {
- debug_assert!(start <= sub(ptr, start_ptr));
- write_mask!(cmp_a, ptr, start_ptr, start, fmt, self);
- let ptr = ptr.add(VECTOR_SIZE);
- debug_assert!(ptr <= end_ptr);
- write_mask!(cmp_b, ptr, start_ptr, start, fmt, self);
- }
- ptr = ptr.add(LOOP_SIZE);
- debug_assert!(start <= sub(ptr, start_ptr));
- }
- }
- // n - 1 loop
- while ptr <= end_ptr.sub(VECTOR_SIZE) {
- let a = _mm256_load_si256(ptr as *const __m256i);
- let cmp = _mm256_cmpgt_epi8(v_flag, _mm256_sub_epi8(a, v_flag_below));
- write_mask!(cmp, ptr, start_ptr, start, fmt, self);
- ptr = ptr.add(VECTOR_SIZE);
- debug_assert!(start <= sub(ptr, start_ptr));
- }
- // res
- while ptr < end_ptr {
- debug_assert!(sub(end_ptr, ptr) < VECTOR_SIZE);
- let b = *ptr;
- let at = sub(ptr, start_ptr);
- if b.wrapping_sub(b'"') <= FLAG {
- escaping_bodies!(start, b, at, fmt, self);
- }
- ptr = ptr.add(1);
- debug_assert!(start <= sub(ptr, start_ptr));
- }
- fmt.write_str(str::from_utf8_unchecked(&self.bytes[start..]))?;
- Ok(())
- }
- }
- }
/// Returns the distance in bytes from `b` up to `a`, i.e. `a - b` computed on
/// raw addresses. `a` must not lie below `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(b <= a);
    let (hi, lo) = (a as usize, b as usize);
    hi - lo
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward reference escaper used to cross-check `escape`.
    fn naive(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '&' => out.push_str("&amp;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&#x27;"),
                '/' => out.push_str("&#x2f;"),
                _ => out.push(c),
            }
        }
        out
    }

    #[test]
    fn test_escape() {
        let string_long = "foobarfoofoofoofoofoofoofoofoofoofoofoofoobarfoobarfoofoofoofoofoo";
        assert_eq!(escape("").to_string(), "");
        assert_eq!(escape("<&>").to_string(), "&lt;&amp;&gt;");
        assert_eq!(escape("bla&").to_string(), "bla&amp;");
        assert_eq!(escape("<foo").to_string(), "&lt;foo");
        assert_eq!(escape("bla&h").to_string(), "bla&amp;h");
        assert_eq!(escape("\"'/").to_string(), "&quot;&#x27;&#x2f;");
        assert_eq!(escape(string_long).to_string(), string_long);
        assert_eq!(
            escape(&[string_long, "<"].join("")).to_string(),
            [string_long, "&lt;"].join("")
        );
        assert_eq!(escape(&"<".repeat(124)).to_string(), "&lt;".repeat(124));
        assert_eq!(
            escape(&[string_long, "<", string_long, string_long].join("")).to_string(),
            [string_long, "&lt;", string_long, string_long].join("")
        );
    }

    #[test]
    fn matches_reference_for_all_offsets_and_lengths() {
        // Mixes escaped characters, spaces, plain ASCII and multi-byte UTF-8.
        // Varying the start offset over a full vector width exercises every
        // alignment of both the scalar and the vector paths.
        let base = "a <b>&\"c\"' d/é x".repeat(12);
        for lo in 0..=VECTOR_SIZE {
            for hi in lo..=base.len() {
                if let Some(s) = base.get(lo..hi) {
                    assert_eq!(escape(s).to_string(), naive(s), "slice {}..{}", lo, hi);
                }
            }
        }
    }
}
Add Comment
Please, Sign In to add comment