// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::RefCell;
use std::mem::MaybeUninit;
use std::ptr;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize};
use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};

use crate::sync::atomic_waker::AtomicWaker;
use crate::sync::error::{RecvError, SendError, TryRecvError};
use crate::sync::mpsc::Container;

/// The capacity of a block.
const CAPACITY: usize = 32;
/// The number of bits the slot index is shifted by inside `tail.index`,
/// leaving bit 0 free for the closed flag.
const INDEX_SHIFT: usize = 1;
/// The flag in `tail.index` that marks the queue as closed.
const CLOSED: usize = 0b01;

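/// A single slot of a block. The sender publishes `value` by setting
/// `has_value` with `Release`; the receiver takes ownership of the slot by
/// swapping the flag back to `false` with `Acquire` before reading.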
pub(crate) struct Node<T> {
    has_value: AtomicBool,
    value: RefCell<MaybeUninit<T>>,
}

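/// A fixed-size segment of the queue. Blocks form a singly linked list
/// through `next`; a block that the receiver has fully drained is reclaimed
/// and re-appended behind the current tail block instead of being freed.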
struct Block<T> {
    data: [Node<T>; CAPACITY],
    next: AtomicPtr<Block<T>>,
}

impl<T> Block<T> {
    fn new() -> Block<T> {
        Block {
            data: unsafe { MaybeUninit::zeroed().assume_init() },
            next: AtomicPtr::new(ptr::null_mut()),
        }
    }

    fn reclaim(&self) {
        self.next.store(ptr::null_mut(), Release);
    }

    fn try_insert(&self, ptr: *mut Block<T>) -> Result<(), *mut Block<T>> {
        match self
            .next
            .compare_exchange(ptr::null_mut(), ptr, AcqRel, Acquire)
        {
            Ok(_) => Ok(()),
            Err(new_ptr) => Err(new_ptr),
        }
    }

    fn insert(&self, ptr: *mut Block<T>) {
        let mut curr = self;
        // The number of attempts is limited. Blocks are recycled to avoid
        // frequent creation and destruction, but retrying too many times may
        // cost more than it saves, so we stop after a fixed number of failed
        // insertions and simply free the block.
        for _ in 0..5 {
            match curr.try_insert(ptr) {
                Ok(_) => return,
                Err(next) => {
                    // The senders and the receiver are synchronized by the
                    // `has_value` flag, therefore this `next` pointer is
                    // guaranteed to be non-null.
                    curr = unsafe { next.as_ref().unwrap() };
                }
            }
        }
        unsafe {
            drop(Box::from_raw(ptr));
        }
    }
}

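/// The receiver-side cursor: the block currently being read and the logical
/// index of the next slot to read. It is intended to be accessed only from
/// the receiver side, so a `RefCell` is enough.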
struct Head<T> {
    block: NonNull<Block<T>>,
    index: usize,
}

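/// The sender-side cursor shared by all senders: the block currently being
/// written and the packed tail index.
///
/// `index` packs two pieces of information into one `usize`: bit 0 holds the
/// `CLOSED` flag and the remaining bits hold the logical slot index, which is
/// advanced by `1 << INDEX_SHIFT`. For example, a raw value of `0b1011`
/// decodes to slot index 5 with the queue closed, while `0b1010` is the same
/// position with the queue still open.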
struct Tail<T> {
    block: AtomicPtr<Block<T>>,
    index: AtomicUsize,
}

/// Unbounded lockless queue.
pub(crate) struct Queue<T> {
    head: RefCell<Head<T>>,
    tail: Tail<T>,
    rx_waker: AtomicWaker,
}

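// Sharing the queue is assumed to be sound under mpsc usage: senders only
// touch the atomic `tail` and the per-node `has_value` flags, while the
// `RefCell`-guarded `head` is only ever used by the single receiver that the
// surrounding channel hands out.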
unsafe impl<T: Send> Send for Queue<T> {}
unsafe impl<T: Send> Sync for Queue<T> {}

impl<T> Queue<T> {
    pub(crate) fn new() -> Queue<T> {
        let block = Box::new(Block::new());
        let block_ptr = Box::into_raw(block);
        Queue {
            head: RefCell::new(Head {
                // block_ptr is non-null
                block: NonNull::new(block_ptr).unwrap(),
                index: 0,
            }),
            tail: Tail {
                block: AtomicPtr::new(block_ptr),
                index: AtomicUsize::new(0),
            },
            rx_waker: AtomicWaker::new(),
        }
    }

    fn send_inner(
        &self,
        index: usize,
        block: &Block<T>,
        new_block: Option<Box<Block<T>>>,
        value: T,
    ) {
        if index + 1 == CAPACITY {
            // If the index is the last one in the block, a new block has
            // already been allocated in `send`.
            let new_block_ptr = Box::into_raw(new_block.unwrap());
            block.insert(new_block_ptr);
            let next_block = block.next.load(Acquire);
            self.tail.block.store(next_block, Release);
            self.tail.index.fetch_add(1 << INDEX_SHIFT, Release);
        }
        // The index is bounded by CAPACITY.
        let node = block.data.get(index).unwrap();
        unsafe {
            node.value.as_ptr().write(MaybeUninit::new(value));
        }
        node.has_value.store(true, Release);
        self.rx_waker.wake();
    }

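    /// Sends a value into the queue.
    ///
    /// A slot is claimed by a compare-exchange on the packed tail index; the
    /// actual write is then done in `send_inner`. If the queue has already
    /// been closed, the value is handed back inside `SendError`.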
    pub(crate) fn send(&self, value: T) -> Result<(), SendError<T>> {
        let mut tail = self.tail.index.load(Acquire);
        let mut block_ptr = self.tail.block.load(Acquire);
        // Keep the allocated block across CAS retries so that a failed
        // compare-exchange does not force a fresh allocation.
        let mut new_block = None;
        loop {
            if tail & CLOSED == CLOSED {
                return Err(SendError(value));
            }
            let index = (tail >> INDEX_SHIFT) % (CAPACITY + 1);
            if index == CAPACITY {
                // The tail currently points at the reserved buffer slot:
                // another sender is in the middle of installing the next
                // block, so reload and retry.
                tail = self.tail.index.load(Acquire);
                block_ptr = self.tail.block.load(Acquire);
                continue;
            }
            let block = unsafe { &*block_ptr };
            if index + 1 == CAPACITY && new_block.is_none() {
                new_block = Some(Box::new(Block::<T>::new()));
            }
            match self
                .tail
                .index
                .compare_exchange(tail, tail + (1 << INDEX_SHIFT), AcqRel, Acquire)
            {
                Ok(_) => {
                    self.send_inner(index, block, new_block, value);
                    return Ok(());
                }
                Err(_) => {
                    tail = self.tail.index.load(Acquire);
                    block_ptr = self.tail.block.load(Acquire);
                }
            }
        }
    }

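    /// Attempts to receive a value without waiting.
    ///
    /// Returns `TryRecvError::Empty` if no value is ready yet and
    /// `TryRecvError::Closed` if no value is ready and the queue has been
    /// marked closed.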
    pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
        let mut head = self.head.borrow_mut();
        let head_index = head.index;
        let block_ptr = head.block.as_ptr();
        let block = unsafe { &*block_ptr };
        let index = head_index % (CAPACITY + 1);
        // The index is guaranteed not to equal CAPACITY because of the
        // `wrapping_add(2)` below.
        let node = block.data.get(index).unwrap();
        // Check whether the node is ready to read.
        if node.has_value.swap(false, Acquire) {
            let value = unsafe { node.value.as_ptr().read().assume_init() };
            if index + 1 == CAPACITY {
                // The sender linked `next` before publishing the last value,
                // so the pointer is non-null here.
                head.block = NonNull::new(block.next.load(Acquire)).unwrap();
                block.reclaim();
                unsafe { (*self.tail.block.load(Acquire)).insert(block_ptr) };
                // When the nodes in a block are used up, one extra logical
                // index is reserved as a buffer for `send` to synchronize its
                // two atomic operations, so skip it here.
                head.index = head_index.wrapping_add(2);
            } else {
                head.index = head_index.wrapping_add(1);
            }
            Ok(value)
        } else {
            let tail_index = self.tail.index.load(Acquire);
            if tail_index & CLOSED == CLOSED {
                Err(TryRecvError::Closed)
            } else {
                Err(TryRecvError::Empty)
            }
        }
    }

    pub(crate) fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        match self.try_recv() {
            Ok(val) => return Ready(Ok(val)),
            Err(TryRecvError::Closed) => return Ready(Err(RecvError)),
            _ => {}
        }

        self.rx_waker.register_by_ref(cx.waker());

        // Check again in case a value arrived or the queue was closed while
        // the waker was being registered.
        match self.try_recv() {
            Ok(val) => Ready(Ok(val)),
            Err(TryRecvError::Closed) => Ready(Err(RecvError)),
            Err(TryRecvError::Empty) => Pending,
        }
    }
}

impl<T> Container for Queue<T> {
    fn close(&self) {
        self.tail.index.fetch_or(CLOSED, Release);
        self.rx_waker.wake();
    }

    fn is_close(&self) -> bool {
        self.tail.index.load(Acquire) & CLOSED == CLOSED
    }

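    // Logical indices advance by CAPACITY + 1 per block (one reserved buffer
    // slot per block), so the raw difference `tail - head` over-counts by one
    // for every block boundary crossed; the `*_redundant` terms subtract those
    // reserved slots. For example, with CAPACITY = 32, head = 0 and tail = 34
    // describe 33 stored values: 34 - 0 - (1 - 0) = 33.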
    fn len(&self) -> usize {
        let head = self.head.borrow().index;
        let mut tail = self.tail.index.load(Acquire) >> INDEX_SHIFT;
        if tail % (CAPACITY + 1) == CAPACITY {
            tail = tail.wrapping_add(1);
        }
        let head_redundant = head / (CAPACITY + 1);
        let tail_redundant = tail / (CAPACITY + 1);
        tail - head - (tail_redundant - head_redundant)
    }
}

impl<T> Drop for Queue<T> {
    fn drop(&mut self) {
        let head = self.head.borrow_mut();
        let mut head_index = head.index;
        let tail_index = self.tail.index.load(Acquire) >> INDEX_SHIFT;
        let mut block_ptr = head.block.as_ptr();
        // Drop every value that was sent but never received.
        while head_index < tail_index {
            let index = head_index % (CAPACITY + 1);
            unsafe {
                if index == CAPACITY {
                    let next_node_ptr = (*block_ptr).next.load(Acquire);
                    drop(Box::from_raw(block_ptr));
                    block_ptr = next_node_ptr;
                } else {
                    // index is bounded by CAPACITY
                    let node = (*block_ptr).data.get_mut(index).unwrap();
                    node.value.get_mut().as_mut_ptr().drop_in_place();
                }
            }
            head_index = head_index.wrapping_add(1);
        }
        // Free the remaining blocks, including the one currently in use.
        while !block_ptr.is_null() {
            unsafe {
                let next_node_ptr = (*block_ptr).next.load(Acquire);
                drop(Box::from_raw(block_ptr));
                block_ptr = next_node_ptr;
            }
        }
    }
}
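
// A minimal usage sketch, assuming the standard test harness; the module and
// test names below are illustrative and not part of the original source. It
// exercises only APIs defined in this file.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    /// Send a few values, receive them in order, then close the queue.
    #[test]
    fn send_recv_close() {
        let queue = Queue::new();
        for i in 0..3 {
            assert!(queue.send(i).is_ok());
        }
        for i in 0..3 {
            assert_eq!(queue.try_recv().ok(), Some(i));
        }
        // Nothing left to read: the queue reports itself as empty.
        assert!(matches!(queue.try_recv(), Err(TryRecvError::Empty)));

        // After closing, a sender gets its value back inside the error.
        queue.close();
        assert!(queue.is_close());
        assert!(matches!(queue.send(42), Err(SendError(42))));
    }
}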