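//! A bounded, lossy single-producer single-consumer channel.
//!
//! The `Sender` and `Receiver` share a heap-allocated, reference-counted ring
//! buffer. The buffer state (read index and message count) is packed into a
//! single `AtomicU64` so both sides can update it without locks. When the
//! buffer is full, `send` overwrites the oldest message and counts it as
//! dropped; a blocking `recv` parks the receiving thread until the sender
//! unparks it.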
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

struct Buffer<M> {
    ref_count: AtomicUsize,
    ptr: *mut M,
    cap: u32,
    /// The number of messages in the buffer is stored in the lower 32 bits.
    /// The index of the next message to read is stored in the upper 32 bits.
    state: AtomicU64,
    dropped: AtomicUsize,
}

impl<M> Buffer<M> {
    /// `cap` must be a power of two.
    fn new(cap: u32) -> Self {
        let mut vec = Vec::with_capacity(cap as usize);
        let cap = vec.capacity() as u32;
        assert!(cap.is_power_of_two());
        let ptr = vec.as_mut_ptr();
        std::mem::forget(vec);
        Self {
            ref_count: AtomicUsize::new(2),
            ptr,
            cap,
            //state: AtomicU64::new(make_state(!0u32, 0)), <- use to test overflow of idx
            state: AtomicU64::new(0),
            dropped: AtomicUsize::new(0),
        }
    }
    unsafe fn dealloc(&self) -> bool {
        if self.ref_count.fetch_sub(1, Ordering::Relaxed) == 1 {
            drop(Vec::from_raw_parts(self.ptr, 0, self.cap as usize));
            true
        } else {
            false
        }
    }
}

/// (next_index, num_msgs)
fn split_state(state: u64) -> (u32, u32) {
    ((state >> 32) as u32, state as u32)
}
fn make_state(next_idx: u32, num_msgs: u32) -> u64 {
    let next_idx = u64::from(next_idx);
    let num_msgs = u64::from(num_msgs);
    (next_idx << 32) | num_msgs
}
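// e.g. make_state(3, 5) == 0x0000_0003_0000_0005, and
// split_state(0x0000_0003_0000_0005) == (3, 5).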

pub struct Sender<M> {
    inner: *mut Buffer<M>,
    recv_thread: std::thread::Thread,
}
impl<M> Sender<M> {
    pub fn send(&mut self, msg: M) {
        unsafe {
            let cap = (*self.inner).cap;
            let state = (*self.inner).state.load(Ordering::Relaxed);
            let (next_read_idx, num_msgs) = split_state(state);
            if num_msgs >= cap {
                let nxt = (next_read_idx + 1) % cap;
                let skip_msg = make_state(nxt, num_msgs - 1);
                if (*self.inner).state.compare_and_swap(state, skip_msg, Ordering::SeqCst) == state {
                    // The oldest message was skipped.
                    (*self.inner).dropped.fetch_add(1, Ordering::Relaxed);
                }
                // If the compare failed, that just means someone else changed the
                // state, but in this case that means the receiver took a message,
                // so we no longer need to drop one.
            }
            // Even if the receiver updates the state during our write, the sum
            // below is never changed by the receiver.
            let next_write_idx = (next_read_idx + num_msgs) % cap;
            let ptr = (*self.inner).ptr.offset(next_write_idx as isize);
            std::ptr::write(ptr, msg);
            // Since the number of messages is in the lower 32 bits, we can just increment it
            // without a CAS loop.
            //
            // Since num_msgs is only incremented in this method, we should never overflow
            // num_msgs. This is because of the num_msgs >= cap check above, and because
            // all access to send is unique, which is enforced by &mut self.
            (*self.inner).state.fetch_add(1, Ordering::Release);
            self.recv_thread.unpark();
        }
    }
}
impl<M> Drop for Sender<M> {
    fn drop(&mut self) {
        unsafe {
            if (*self.inner).dealloc() {
                drop(Box::from_raw(self.inner));
            }
        }
    }
}
unsafe impl<M> Send for Sender<M> {}
unsafe impl<M> Sync for Sender<M> {}

/// The Receiver is not Send since the Sender has a handle to the thread it was
/// created in.
pub struct Receiver<M> {
    inner: *mut Buffer<M>,
}
impl<M> Receiver<M> {
    pub fn recv(&mut self) -> M {
        loop {
            if let Some(m) = self.try_recv() {
                return m;
            }
            // There is no race condition here: if unpark
            // happens first, park will return immediately.
            // Hence there is no risk of a deadlock.
            std::thread::park();
        }
    }
    pub fn try_recv(&mut self) -> Option<M> {
        unsafe {
            let state = (*self.inner).state.load(Ordering::Acquire);
            let (next_read_idx, num_msgs) = split_state(state);
            if num_msgs == 0 {
                None
            } else {
                let cap = (*self.inner).cap;
                let ptr = (*self.inner).ptr.offset((next_read_idx % cap) as isize);
                let val = std::ptr::read(ptr);

                // This adds one to next_read_idx and subtracts one from the number
                // of messages.
                // Note that if next_read_idx overflows, this does not affect the
                // number of messages, and as long as cap is a power of two,
                // it behaves correctly.
                (*self.inner).state.fetch_add((1u64 << 32) - 1, Ordering::Relaxed);
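                // For example, a state of (next_read_idx = 2, num_msgs = 3),
                // i.e. 0x0000_0002_0000_0003, becomes 0x0000_0003_0000_0002
                // after the addition, i.e. (next_read_idx = 3, num_msgs = 2).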

                Some(val)
            }
        }
    }
    pub fn num_dropped(&self) -> usize {
        unsafe {
            (*self.inner).dropped.load(Ordering::Relaxed)
        }
    }
}
impl<M> Drop for Receiver<M> {
    fn drop(&mut self) {
        unsafe {
            if (*self.inner).dealloc() {
                drop(Box::from_raw(self.inner));
            }
        }
    }
}

pub fn channel<M>(cap: u32) -> (Sender<M>, Receiver<M>) {
    let thread_handle = std::thread::current();
    let inner = Buffer::new(cap);
    let ptr = Box::into_raw(Box::new(inner));
    (
        Sender { inner: ptr, recv_thread: thread_handle },
        Receiver { inner: ptr },
    )
}

fn main() {
    const COUNT: u32 = 10_000_000;
    let (mut send, mut recv) = channel(1024 * 1024);
    let handle = std::thread::spawn(move || {
        let now = std::time::Instant::now();
        for i in 0..COUNT {
            send.send(i);
        }
        println!("Done. Used {:?} per send.", now.elapsed() / COUNT);
    });
    for _ in 0..COUNT {
        let _ = recv.recv();
    }
    handle.join().unwrap();
    println!("{} were dropped.", recv.num_dropped());
}
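
// A minimal sketch demonstrating the drop-oldest behaviour. It assumes
// Vec::with_capacity(4) allocates exactly 4 slots for u32, which holds in
// practice but is not guaranteed by Vec's documentation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn drops_oldest_when_full() {
        // Capacity 4: sending 10 messages before any receive overwrites the 6 oldest.
        let (mut send, mut recv) = channel::<u32>(4);
        let handle = std::thread::spawn(move || {
            for i in 0..10 {
                send.send(i);
            }
        });
        handle.join().unwrap();
        // Only the four newest messages remain, in order.
        let survivors: Vec<u32> = std::iter::from_fn(|| recv.try_recv()).collect();
        assert_eq!(survivors, vec![6, 7, 8, 9]);
        assert_eq!(recv.num_dropped(), 6);
    }
}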