Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
/// Shared storage behind a `Sender`/`Receiver` pair: a ring of `cap` message
/// slots plus the atomics that coordinate the two handles.
struct Buffer<M> {
    /// Number of handles still alive; starts at 2 (one `Sender`, one `Receiver`).
    ref_count: AtomicUsize,
    /// Start of an allocation of exactly `cap` slots. The `num_msgs` slots
    /// starting at `next_idx % cap` (see `state`) are expected to hold
    /// initialised messages; all other slots are uninitialised.
    ptr: *mut M,
    /// Number of slots; a power of two and at least 2.
    cap: u32,
    /// The number of messages in the buffer is stored in the lower 32 bits.
    /// The index of the next message to read is stored in the upper 32 bits.
    state: AtomicU64,
    /// How many messages the sender discarded because the buffer was full.
    dropped: AtomicUsize,
}

impl<M> Buffer<M> {
    /// Allocates an empty buffer with exactly `cap` message slots.
    ///
    /// # Panics
    ///
    /// Panics if `cap` is not a power of two, or if `cap` is 1. With a single
    /// slot, the sender discarding the only message while the receiver is
    /// consuming it can drive the message count below zero and corrupt `state`.
    fn new(cap: u32) -> Self {
        assert!(cap.is_power_of_two(), "buffer capacity must be a power of two");
        assert!(cap >= 2, "buffer capacity must be at least 2");
        // Allocate exactly `cap` uninitialised slots. `Vec::with_capacity` only
        // promises *at least* `cap` (and reports `usize::MAX` for zero-sized `M`),
        // so its reported capacity must not be used as the ring size.
        let slots: Box<[std::mem::MaybeUninit<M>]> =
            (0..cap).map(|_| std::mem::MaybeUninit::uninit()).collect();
        let ptr = Box::into_raw(slots).cast::<M>();
        Self {
            ref_count: AtomicUsize::new(2),
            ptr,
            cap,
            // To exercise read-index wraparound, start from `make_state(!0u32, 0)`.
            state: AtomicU64::new(0),
            dropped: AtomicUsize::new(0),
        }
    }

    /// Releases one handle's share of the buffer.
    ///
    /// Returns `true` when this was the last handle. In that case every message
    /// still queued has been dropped and the slot storage freed. The caller is
    /// then responsible for freeing the `Buffer` allocation itself.
    ///
    /// # Safety
    ///
    /// Each handle must call this at most once. After it returns `true`, `self`
    /// must not be used except to free the `Buffer` allocation.
    unsafe fn dealloc(&self) -> bool {
        // Same protocol as `Arc`'s drop: Release publishes this handle's last
        // writes (e.g. the sender's final message), and the Acquire fence in the
        // last handle makes them visible before teardown.
        if self.ref_count.fetch_sub(1, Ordering::Release) != 1 {
            return false;
        }
        std::sync::atomic::fence(Ordering::Acquire);

        // Drop every message that was sent but never received.
        let state = self.state.load(Ordering::Relaxed);
        let next_idx = (state >> 32) as u32;
        let num_msgs = state as u32;
        for i in 0..num_msgs {
            // The read index may have wrapped past u32::MAX; since `cap` is a
            // power of two (it divides 2^32), wrapping keeps the slot mapping.
            let slot = next_idx.wrapping_add(i) % self.cap;
            std::ptr::drop_in_place(self.ptr.add(slot as usize));
        }

        // Free the slot storage without touching the (now dropped or never
        // initialised) contents.
        drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(
            self.ptr.cast::<std::mem::MaybeUninit<M>>(),
            self.cap as usize,
        )));
        true
    }
}
/// Unpacks a state word into `(next_index, num_msgs)`: the read index comes
/// from the upper 32 bits and the message count from the lower 32 bits.
fn split_state(state: u64) -> (u32, u32) {
    let low_mask: u64 = 0xFFFF_FFFF;
    let num_msgs = (state & low_mask) as u32;
    let next_index = (state >> 32) as u32;
    (next_index, num_msgs)
}
/// Packs `next_idx` into the upper 32 bits and `num_msgs` into the lower 32
/// bits of a single state word; the inverse of `split_state`.
fn make_state(next_idx: u32, num_msgs: u32) -> u64 {
    let high = u64::from(next_idx) << 32;
    high | u64::from(num_msgs)
}
- pub struct Sender<M> {
- inner: *mut Buffer<M>,
- recv_thread: std::thread::Thread,
- }
- impl<M> Sender<M> {
- pub fn send(&mut self, msg: M) {
- unsafe {
- let cap = (*self.inner).cap;
- let state = (*self.inner).state.load(Ordering::Relaxed);
- let (next_read_idx, num_msgs) = split_state(state);
- if num_msgs >= cap {
- let nxt = (next_read_idx+1) % cap;
- let skip_msg = make_state(nxt, num_msgs-1);
- if (*self.inner).state.compare_and_swap(state, skip_msg, Ordering::SeqCst) == state {
- // The message was skipped.
- // If the compare failed, that just means someone else changed it, but in
- // this case that means they received a message, meaning we no longer need to
- // drop one.
- (*self.inner).dropped.fetch_add(1, Ordering::Relaxed);
- }
- }
- // Even if the receiver updates the state during our write, the sum
- // below is never changed by the receiver.
- let next_write_idx = (next_read_idx + num_msgs) % cap;
- let ptr = (*self.inner).ptr.offset(next_write_idx as isize);
- std::ptr::write(ptr, msg);
- // Since the number of messages is in the lower 32 bits, we can just increment it
- // without a CAS loop.
- //
- // Since num_msgs is only incremented in this method, we should never overflow
- // num_msgs. This is because of the num_msgs >= cap check above, and because
- // all access to send is unique, which is enforced by &mut self.
- (*self.inner).state.fetch_add(1, Ordering::Release);
- self.recv_thread.unpark();
- }
- }
- }
- impl<M> Drop for Sender<M> {
- fn drop(&mut self) {
- unsafe {
- if (*self.inner).dealloc() {
- drop(Box::from_raw(self.inner));
- }
- }
- }
- }
- unsafe impl<M> Send for Sender<M> {}
- unsafe impl<M> Sync for Sender<M> {}
- /// The Receiver is not Send since the Sender has a handle to the thread it was
- /// created in.
- pub struct Receiver<M> {
- inner: *mut Buffer<M>,
- }
- impl<M> Receiver<M> {
- pub fn recv(&mut self) -> M {
- loop {
- if let Some(m) = self.try_recv() {
- return m;
- }
- // There is no race condition here: If unpark
- // happens first, park will return immediately.
- // Hence there is no risk of a deadlock.
- std::thread::park();
- }
- }
- pub fn try_recv(&mut self) -> Option<M> {
- unsafe {
- let state = (*self.inner).state.load(Ordering::Acquire);
- let (next_read_idx, num_msgs) = split_state(state);
- if num_msgs == 0 {
- None
- } else {
- let cap = (*self.inner).cap;
- let ptr = (*self.inner).ptr.offset((next_read_idx % cap) as isize);
- let val = std::ptr::read(ptr);
- // This adds one to next_read_idx and subtracts one from the number
- // of messages.
- // Note that if next_idx overflows, this does not affect the
- // number of messages and as long as cap is a multiple of two,
- // it behaves correctly.
- (*self.inner).state.fetch_add((1u64 << 32) - 1, Ordering::Relaxed);
- Some(val)
- }
- }
- }
- pub fn num_dropped(&self) -> usize {
- unsafe {
- (*self.inner).dropped.load(Ordering::Relaxed)
- }
- }
- }
- impl<M> Drop for Receiver<M> {
- fn drop(&mut self) {
- unsafe {
- if (*self.inner).dealloc() {
- drop(Box::from_raw(self.inner));
- }
- }
- }
- }
- pub fn channel<M>(cap: u32) -> (Sender<M>, Receiver<M>) {
- let thread_handle = std::thread::current();
- let inner = Buffer::new(cap);
- let ptr = Box::into_raw(Box::new(inner));
- (
- Sender { inner: ptr, recv_thread: thread_handle, },
- Receiver { inner: ptr, },
- )
- }
- fn main() {
- const COUNT: u32 = 10000000;
- let (mut send, mut recv) = channel(1024 * 1024);
- let handle = std::thread::spawn(move || {
- let now = std::time::Instant::now();
- for i in 0..COUNT {
- send.send(i);
- }
- println!("Done. Used {:?} per send.", now.elapsed() / COUNT);
- });
- for _ in 0..COUNT {
- let _ = recv.recv();
- }
- handle.join().unwrap();
- println!("{} were dropped.", recv.num_dropped());
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement