Advertisement
Guest User

Untitled

a guest
Sep 18th, 2019
112
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.90 KB | None | 0 0
  1. use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
  2.  
  /// Shared state of a single-producer, single-consumer ring buffer whose sender
  /// discards the oldest unread message when it finds the buffer full.
  ///
  /// One `Sender` and one `Receiver` reference the same `Buffer`; whichever of them
  /// releases its reference last frees it (see [`Buffer::dealloc`]).
  struct Buffer<M> {
      /// Number of endpoints still referencing this buffer (starts at 2).
      ref_count: AtomicUsize,
      /// Storage for `cap` slots of `M`, taken from a `Vec<M>` in [`Buffer::new`].
      ptr: *mut M,
      /// Number of slots. Always a power of two, so `index % cap` keeps selecting the
      /// right slot when the 32-bit read index wraps around.
      cap: u32,
      /// The number of messages in the buffer is stored in the lower 32 bits.
      /// The index of the next message to read is stored in the upper 32 bits.
      state: AtomicU64,
      /// Number of messages the sender discarded because the buffer was full.
      dropped: AtomicUsize,
  }

  impl<M> Buffer<M> {
      /// Allocates storage for `cap` messages, with a reference count of 2
      /// (one sender, one receiver).
      ///
      /// # Panics
      /// Panics if `cap` is not a power of two (this includes `0`).
      fn new(cap: u32) -> Self {
          assert!(
              cap.is_power_of_two(),
              "channel capacity must be a power of two, got {}",
              cap
          );
          let mut vec: Vec<M> = Vec::with_capacity(cap as usize);
          // `dealloc` rebuilds the Vec from `cap`, so a real allocation must have exactly
          // that capacity. Zero-sized messages never allocate: Vec reports `usize::MAX`
          // for them, which does not fit in `u32`, so they are special-cased.
          if std::mem::size_of::<M>() != 0 {
              assert_eq!(vec.capacity(), cap as usize, "unexpected ring buffer capacity");
          }
          let ptr = vec.as_mut_ptr();
          std::mem::forget(vec);
          Self {
              ref_count: AtomicUsize::new(2),
              ptr,
              cap,
              // Starting from `make_state(!0u32, 0)` instead exercises wrap-around of the
              // 32-bit read index.
              state: AtomicU64::new(0),
              dropped: AtomicUsize::new(0),
          }
      }

      /// Releases one endpoint's reference. The caller releasing the last reference
      /// drops every message that was sent but never received and frees the slot
      /// storage.
      ///
      /// Returns `true` only to the caller that released the final reference; that
      /// caller is then responsible for freeing the `Buffer` itself.
      ///
      /// # Safety
      /// Each endpoint may call this at most once and must not touch the buffer
      /// afterwards.
      unsafe fn dealloc(&self) -> bool {
          // Release publishes this endpoint's slot reads/writes; the Acquire fence below
          // makes them visible to the endpoint that frees the storage (as `Arc` does).
          if self.ref_count.fetch_sub(1, Ordering::Release) != 1 {
              return false;
          }
          std::sync::atomic::fence(Ordering::Acquire);

          // Slots [next_read_idx, next_read_idx + num_msgs) (mod cap) still own
          // messages nobody received; drop them before the storage goes away.
          let (next_read_idx, num_msgs) = split_state(self.state.load(Ordering::Relaxed));
          for offset in 0..num_msgs {
              let slot = next_read_idx.wrapping_add(offset) % self.cap;
              std::ptr::drop_in_place(self.ptr.add(slot as usize));
          }
          // Zero-sized messages were never allocated, so there is nothing to free.
          if std::mem::size_of::<M>() != 0 {
              drop(Vec::from_raw_parts(self.ptr, 0, self.cap as usize));
          }
          true
      }
  }

  /// Splits a packed state word into `(next_index, num_msgs)`.
  fn split_state(state: u64) -> (u32, u32) {
      let next_idx = (state >> 32) as u32;
      // Truncating keeps only the low 32 bits, which hold the message count.
      let num_msgs = state as u32;
      (next_idx, num_msgs)
  }

  /// Packs `next_idx` into the high 32 bits and `num_msgs` into the low 32 bits.
  fn make_state(next_idx: u32, num_msgs: u32) -> u64 {
      (u64::from(next_idx) << 32) | u64::from(num_msgs)
  }
  49.  
  50. pub struct Sender<M> {
  51. inner: *mut Buffer<M>,
  52. recv_thread: std::thread::Thread,
  53. }
  54. impl<M> Sender<M> {
  55. pub fn send(&mut self, msg: M) {
  56. unsafe {
  57. let cap = (*self.inner).cap;
  58. let state = (*self.inner).state.load(Ordering::Relaxed);
  59. let (next_read_idx, num_msgs) = split_state(state);
  60. if num_msgs >= cap {
  61. let nxt = (next_read_idx+1) % cap;
  62. let skip_msg = make_state(nxt, num_msgs-1);
  63. if (*self.inner).state.compare_and_swap(state, skip_msg, Ordering::SeqCst) == state {
  64. // The message was skipped.
  65. // If the compare failed, that just means someone else changed it, but in
  66. // this case that means they received a message, meaning we no longer need to
  67. // drop one.
  68. (*self.inner).dropped.fetch_add(1, Ordering::Relaxed);
  69. }
  70. }
  71. // Even if the receiver updates the state during our write, the sum
  72. // below is never changed by the receiver.
  73. let next_write_idx = (next_read_idx + num_msgs) % cap;
  74. let ptr = (*self.inner).ptr.offset(next_write_idx as isize);
  75. std::ptr::write(ptr, msg);
  76. // Since the number of messages is in the lower 32 bits, we can just increment it
  77. // without a CAS loop.
  78. //
  79. // Since num_msgs is only incremented in this method, we should never overflow
  80. // num_msgs. This is because of the num_msgs >= cap check above, and because
  81. // all access to send is unique, which is enforced by &mut self.
  82. (*self.inner).state.fetch_add(1, Ordering::Release);
  83. self.recv_thread.unpark();
  84. }
  85. }
  86. }
  87. impl<M> Drop for Sender<M> {
  88. fn drop(&mut self) {
  89. unsafe {
  90. if (*self.inner).dealloc() {
  91. drop(Box::from_raw(self.inner));
  92. }
  93. }
  94. }
  95. }
  96. unsafe impl<M> Send for Sender<M> {}
  97. unsafe impl<M> Sync for Sender<M> {}
  98.  
  99. /// The Receiver is not Send since the Sender has a handle to the thread it was
  100. /// created in.
  101. pub struct Receiver<M> {
  102. inner: *mut Buffer<M>,
  103. }
  104. impl<M> Receiver<M> {
  105. pub fn recv(&mut self) -> M {
  106. loop {
  107. if let Some(m) = self.try_recv() {
  108. return m;
  109. }
  110. // There is no race condition here: If unpark
  111. // happens first, park will return immediately.
  112. // Hence there is no risk of a deadlock.
  113. std::thread::park();
  114. }
  115. }
  116. pub fn try_recv(&mut self) -> Option<M> {
  117. unsafe {
  118. let state = (*self.inner).state.load(Ordering::Acquire);
  119. let (next_read_idx, num_msgs) = split_state(state);
  120. if num_msgs == 0 {
  121. None
  122. } else {
  123. let cap = (*self.inner).cap;
  124. let ptr = (*self.inner).ptr.offset((next_read_idx % cap) as isize);
  125. let val = std::ptr::read(ptr);
  126.  
  127. // This adds one to next_read_idx and subtracts one from the number
  128. // of messages.
  129. // Note that if next_idx overflows, this does not affect the
  130. // number of messages and as long as cap is a multiple of two,
  131. // it behaves correctly.
  132. (*self.inner).state.fetch_add((1u64 << 32) - 1, Ordering::Relaxed);
  133.  
  134. Some(val)
  135. }
  136. }
  137. }
  138. pub fn num_dropped(&self) -> usize {
  139. unsafe {
  140. (*self.inner).dropped.load(Ordering::Relaxed)
  141. }
  142. }
  143. }
  144. impl<M> Drop for Receiver<M> {
  145. fn drop(&mut self) {
  146. unsafe {
  147. if (*self.inner).dealloc() {
  148. drop(Box::from_raw(self.inner));
  149. }
  150. }
  151. }
  152. }
  153.  
  154. pub fn channel<M>(cap: u32) -> (Sender<M>, Receiver<M>) {
  155. let thread_handle = std::thread::current();
  156. let inner = Buffer::new(cap);
  157. let ptr = Box::into_raw(Box::new(inner));
  158. (
  159. Sender { inner: ptr, recv_thread: thread_handle, },
  160. Receiver { inner: ptr, },
  161. )
  162. }
  163.  
  164. fn main() {
  165. const COUNT: u32 = 10000000;
  166. let (mut send, mut recv) = channel(1024 * 1024);
  167. let handle = std::thread::spawn(move || {
  168. let now = std::time::Instant::now();
  169. for i in 0..COUNT {
  170. send.send(i);
  171. }
  172. println!("Done. Used {:?} per send.", now.elapsed() / COUNT);
  173. });
  174. for _ in 0..COUNT {
  175. let _ = recv.recv();
  176. }
  177. handle.join().unwrap();
  178. println!("{} were dropped.", recv.num_dropped());
  179. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement