Guest User

Untitled

a guest
Nov 16th, 2018
116
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.37 KB | None | 0 0
  1. use std::arch::x86_64::*;
  2. use std::fmt;
  3. use std::mem::size_of;
  4. use std::str;
  5.  
  6. const VECTOR_SIZE: usize = size_of::<__m256i>();
  7. const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
  8. const LOOP_SIZE: usize = 2 * VECTOR_SIZE;
  9. const FLAG: u8 = b'>' - b'"';
  10.  
  11. pub fn escape(s: &str) -> Escaped {
  12. Escaped {
  13. bytes: s.as_bytes(),
  14. }
  15. }
  16.  
  17. macro_rules! escaping_body {
  18. ($start:ident, $i:ident, $fmt:ident, $_self:ident, $quote:expr) => {{
  19. if $start < $i {
  20. $fmt.write_str(str::from_utf8_unchecked(&$_self.bytes[$start..$i]))?;
  21. }
  22. $fmt.write_str($quote)?;
  23. $start = $i + 1;
  24. }};
  25. }
  26.  
  27. macro_rules! escaping_bodies {
  28. ($start:ident, $b: ident, $i:ident, $fmt:ident, $_self:ident) => {{
  29. match $b {
  30. b'<' => escaping_body!($start, $i, $fmt, $_self, "<"),
  31. b'>' => escaping_body!($start, $i, $fmt, $_self, ">"),
  32. b'&' => escaping_body!($start, $i, $fmt, $_self, "&"),
  33. b'"' => escaping_body!($start, $i, $fmt, $_self, """),
  34. b'\'' => escaping_body!($start, $i, $fmt, $_self, "&#x27;"),
  35. b'/' => escaping_body!($start, $i, $fmt, $_self, "&#x2f;"),
  36. _ => (),
  37. }
  38. }};
  39. }
  40.  
  41. macro_rules! write_mask {
  42. ($cmp: ident, $ptr: ident, $start_ptr: ident, $start: ident, $fmt: ident, $_self:ident) => {{
  43. let mut mask = _mm256_movemask_epi8($cmp);
  44. if mask != 0 {
  45. let at = sub($ptr, $start_ptr);
  46. let mut pos = mask.trailing_zeros() as usize;
  47. loop {
  48. let b = *$ptr.add(pos);
  49. let at = at + pos;
  50.  
  51. escaping_bodies!($start, b, at, $fmt, $_self);
  52.  
  53. debug_assert!(32 > pos && 0 <= pos);
  54. mask ^= 1 << pos;
  55. if mask == 0 {
  56. break;
  57. }
  58.  
  59. pos = mask.trailing_zeros() as usize;
  60. }
  61.  
  62. debug_assert!(at == sub($ptr, $start_ptr));
  63. }
  64. }};
  65. }
  66.  
  67. pub struct Escaped<'a> {
  68. bytes: &'a [u8],
  69. }
  70.  
  71. impl<'a> fmt::Display for Escaped<'a> {
  72. fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
  73. let len = self.bytes.len();
  74. let mut start = 0;
  75.  
  76. unsafe {
  77. if len < VECTOR_SIZE {
  78. for (i, b) in self.bytes.iter().enumerate() {
  79. if b.wrapping_sub(b'"') <= FLAG {
  80. let b = *b;
  81. escaping_bodies!(start, b, i, fmt, self);
  82. }
  83. }
  84. fmt.write_str(str::from_utf8_unchecked(&self.bytes[start..]))?;
  85. return Ok(());
  86. }
  87.  
  88. let v_flag = _mm256_set1_epi8((FLAG + 1) as i8);
  89. let v_flag_below = _mm256_set1_epi8(b'"' as i8);
  90.  
  91. let start_ptr = self.bytes.as_ptr();
  92. let end_ptr = self.bytes[len..].as_ptr();
  93. let mut ptr = start_ptr;
  94. let align = VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN);
  95.  
  96. // Align
  97. for i in 0..align {
  98. let b = *ptr;
  99. if b.wrapping_sub(b'"') <= FLAG {
  100. escaping_bodies!(start, b, i, fmt, self);
  101. }
  102.  
  103. ptr = ptr.add(1);
  104. }
  105.  
  106. debug_assert!(start <= sub(ptr, start_ptr));
  107. debug_assert!(ptr > start_ptr && end_ptr.sub(VECTOR_SIZE) >= start_ptr);
  108.  
  109. // body
  110. if LOOP_SIZE <= len {
  111. while ptr <= end_ptr.sub(LOOP_SIZE) {
  112. debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);
  113.  
  114. let a = _mm256_load_si256(ptr as *const __m256i);
  115. let b = _mm256_load_si256(ptr.add(VECTOR_SIZE) as *const __m256i);
  116. let cmp_a = _mm256_cmpgt_epi8(v_flag, _mm256_sub_epi8(a, v_flag_below));
  117. let cmp_b = _mm256_cmpgt_epi8(v_flag, _mm256_sub_epi8(b, v_flag_below));
  118. let mask = _mm256_movemask_epi8(_mm256_or_si256(cmp_a, cmp_b));
  119.  
  120. if mask != 0 {
  121. debug_assert!(start <= sub(ptr, start_ptr));
  122. write_mask!(cmp_a, ptr, start_ptr, start, fmt, self);
  123.  
  124. let ptr = ptr.add(VECTOR_SIZE);
  125. debug_assert!(ptr <= end_ptr);
  126.  
  127. write_mask!(cmp_b, ptr, start_ptr, start, fmt, self);
  128. }
  129. ptr = ptr.add(LOOP_SIZE);
  130. debug_assert!(start <= sub(ptr, start_ptr));
  131. }
  132. }
  133.  
  134. // n - 1 loop
  135. while ptr <= end_ptr.sub(VECTOR_SIZE) {
  136. let a = _mm256_load_si256(ptr as *const __m256i);
  137. let cmp = _mm256_cmpgt_epi8(v_flag, _mm256_sub_epi8(a, v_flag_below));
  138. write_mask!(cmp, ptr, start_ptr, start, fmt, self);
  139.  
  140. ptr = ptr.add(VECTOR_SIZE);
  141. debug_assert!(start <= sub(ptr, start_ptr));
  142. }
  143.  
  144. // res
  145. while ptr < end_ptr {
  146. debug_assert!(sub(end_ptr, ptr) < VECTOR_SIZE);
  147. let b = *ptr;
  148. let at = sub(ptr, start_ptr);
  149. if b.wrapping_sub(b'"') <= FLAG {
  150. escaping_bodies!(start, b, at, fmt, self);
  151. }
  152.  
  153. ptr = ptr.add(1);
  154. debug_assert!(start <= sub(ptr, start_ptr));
  155. }
  156.  
  157. fmt.write_str(str::from_utf8_unchecked(&self.bytes[start..]))?;
  158. Ok(())
  159. }
  160. }
  161. }
  162.  
  163. /// Subtract `b` from `a` and return the difference. `a` should be greater than
  164. /// or equal to `b`.
  165. fn sub(a: *const u8, b: *const u8) -> usize {
  166. debug_assert!(a >= b);
  167. (a as usize) - (b as usize)
  168. }
  169.  
  170.  
  171. #[cfg(test)]
  172. mod tests {
  173. use super::*;
  174.  
  175. #[test]
  176. fn test_escape() {
  177. let string_long = "foobarfoofoofoofoofoofoofoofoofoofoofoofoobarfoobarfoofoofoofoofoo";
  178. assert_eq!(escape("<&>").to_string(), "<&>");
  179. assert_eq!(escape("bla&").to_string(), "bla&");
  180. assert_eq!(escape("<foo").to_string(), "<foo");
  181. assert_eq!(escape("bla&h").to_string(), "bla&h");
  182. assert_eq!(escape(string_long).to_string(), string_long);
  183. assert_eq!(escape(&[string_long, "<"].join("")).to_string(), [string_long, "<"].join(""));
  184. assert_eq!(escape(unsafe { str::from_utf8_unchecked(&[b'<'; 124]) }).to_string(), ["<"; 124].join(""));
  185. assert_eq!(escape(&[string_long, "<", string_long, string_long].join("")).to_string(), [string_long, "<", string_long, string_long].join(""));
  186. }
  187. }
Add Comment
Please, Sign In to add comment