1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //! Watch channel
15 
16 use std::fmt::{Debug, Formatter};
17 use std::ops::Deref;
18 use std::sync::atomic::AtomicUsize;
19 use std::sync::atomic::Ordering::{Acquire, Release};
20 use std::sync::{Arc, RwLock, RwLockReadGuard};
21 use std::task::Poll::{Pending, Ready};
22 use std::task::{Context, Poll};
23 
24 use crate::futures::poll_fn;
25 use crate::sync::error::{RecvError, SendError};
26 use crate::sync::wake_list::WakerList;
27 
/// Number of low bits of the channel state reserved for flags; the value
/// version is stored in the bits above this shift.
const VERSION_SHIFT: usize = 1;
/// Flag bit of the channel state that marks the channel as closed.
const CLOSED: usize = 1;
32 
33 /// Creates a new watch channel with a `Sender` and `Receiver` handle pair.
34 ///
35 /// The value sent by the `Sender` can be seen by all receivers, but only the
36 /// last value sent by `Sender` is visible to the `Receiver`.
37 ///
38 /// # Examples
39 ///
40 /// ```
41 /// use ylong_runtime::sync::watch;
42 /// async fn io_func() {
43 ///     let (tx, mut rx) = watch::channel(1);
44 ///     ylong_runtime::spawn(async move {
45 ///         let _ = rx.notified().await;
46 ///         assert_eq!(*rx.borrow(), 2);
47 ///     });
48 ///
49 ///     let _ = tx.send(2);
50 /// }
51 /// ```
channel<T>(value: T) -> (Sender<T>, Receiver<T>)52 pub fn channel<T>(value: T) -> (Sender<T>, Receiver<T>) {
53     let channel = Arc::new(Channel::new(value));
54     let tx = Sender {
55         channel: channel.clone(),
56     };
57     let rx = Receiver {
58         channel,
59         version: 0,
60     };
61     (tx, rx)
62 }
63 
/// The sending half of a watch channel.
///
/// A [`Sender`] and [`Receiver`] handle pair is created by the [`channel`]
/// function. Dropping the sender closes the channel.
///
/// # Examples
///
/// ```
/// use ylong_runtime::sync::watch;
/// async fn io_func() {
///     let (tx, mut rx) = watch::channel(1);
///     assert_eq!(tx.receiver_count(), 1);
///     ylong_runtime::spawn(async move {
///         let _ = rx.notified().await;
///         assert_eq!(*rx.borrow(), 2);
///     });
///
///     let _ = tx.send(2);
/// }
/// ```
#[derive(Debug)]
pub struct Sender<T> {
    // State shared with every receiver of this channel.
    channel: Arc<Channel<T>>,
}
87 
88 impl<T> Sender<T> {
89     /// Sends values to the associated [`Receiver`].
90     ///
91     /// An error containing the sent value would be returned if all receivers
92     /// are dropped.
93     ///
94     /// # Examples
95     ///
96     /// ```
97     /// use ylong_runtime::sync::watch;
98     /// async fn io_func() {
99     ///     let (tx, mut rx) = watch::channel(1);
100     ///     ylong_runtime::spawn(async move {
101     ///         let _ = rx.notified().await;
102     ///         assert_eq!(*rx.borrow(), 2);
103     ///     });
104     ///
105     ///     let _ = tx.send(2);
106     /// }
107     /// ```
send(&self, value: T) -> Result<(), SendError<T>>108     pub fn send(&self, value: T) -> Result<(), SendError<T>> {
109         if self.channel.rx_cnt.load(Acquire) == 0 {
110             return Err(SendError(value));
111         }
112         let mut lock = self.channel.value.write().unwrap();
113         *lock = value;
114         self.channel.state.version_update();
115         drop(lock);
116         self.channel.waker_list.notify_all();
117         Ok(())
118     }
119 
120     /// Creates a new [`Receiver`] associated with oneself.
121     ///
122     /// The newly created receiver will mark all the values sent before as seen.
123     ///
124     /// This method can create a new receiver when there is no receiver
125     /// available.
126     ///
127     /// # Examples
128     ///
129     /// ```
130     /// use ylong_runtime::sync::watch;
131     /// async fn io_func() {
132     ///     let (tx, mut rx) = watch::channel(1);
133     ///     let mut rx2 = tx.subscribe();
134     ///     assert_eq!(*rx.borrow(), 1);
135     ///     assert_eq!(*rx2.borrow(), 1);
136     ///     let _ = tx.send(2);
137     ///     assert_eq!(*rx.borrow(), 2);
138     ///     assert_eq!(*rx2.borrow(), 2);
139     /// }
140     /// ```
subscribe(&self) -> Receiver<T>141     pub fn subscribe(&self) -> Receiver<T> {
142         let (value_version, _) = self.channel.state.load();
143         self.channel.rx_cnt.fetch_add(1, Release);
144         Receiver {
145             channel: self.channel.clone(),
146             version: value_version,
147         }
148     }
149 
150     /// Gets the number of receivers associated with oneself.
151     ///
152     /// # Examples
153     ///
154     /// ```
155     /// use ylong_runtime::sync::watch;
156     /// async fn io_func() {
157     ///     let (tx, rx) = watch::channel(1);
158     ///     assert_eq!(tx.receiver_count(), 1);
159     ///     let rx2 = tx.subscribe();
160     ///     assert_eq!(tx.receiver_count(), 2);
161     ///     let rx3 = rx.clone();
162     ///     assert_eq!(tx.receiver_count(), 3);
163     /// }
164     /// ```
receiver_count(&self) -> usize165     pub fn receiver_count(&self) -> usize {
166         self.channel.rx_cnt.load(Acquire)
167     }
168 }
169 
170 impl<T> Drop for Sender<T> {
drop(&mut self)171     fn drop(&mut self) {
172         self.channel.close();
173     }
174 }
175 
/// Reference to the inner value.
///
/// This reference holds a read lock on the internal value, so holding this
/// reference will block the sender from sending data. When the watch channel
/// runs in an environment that allows !Send futures, you need to ensure that
/// the reference is not held across an `.await` point to avoid deadlocks.
///
/// The priority policy of RwLock is consistent with the `std::RwLock`.
///
/// # Examples
///
/// ```
/// use ylong_runtime::sync::watch;
/// async fn io_func() {
///     let (tx, mut rx) = watch::channel(1);
///     let v1 = rx.borrow();
///     assert_eq!(*v1, 1);
///     assert!(!v1.is_notified());
///     drop(v1);
///
///     let _ = tx.send(2);
///     let v2 = rx.borrow_notify();
///     assert_eq!(*v2, 2);
///     assert!(v2.is_notified());
///     drop(v2);
///
///     let v3 = rx.borrow_notify();
///     assert_eq!(*v3, 2);
///     assert!(!v3.is_notified());
/// }
/// ```
#[derive(Debug)]
pub struct ValueRef<'a, T> {
    // Read guard over the channel's stored value.
    value: RwLockReadGuard<'a, T>,
    // Flag supplied at construction; the receiver's borrow methods set it when
    // the receiver's recorded version differs from the channel version.
    is_notified: bool,
}
212 
213 impl<'a, T> ValueRef<'a, T> {
new(value: RwLockReadGuard<'a, T>, is_notified: bool) -> ValueRef<'a, T>214     fn new(value: RwLockReadGuard<'a, T>, is_notified: bool) -> ValueRef<'a, T> {
215         ValueRef { value, is_notified }
216     }
217 
218     /// Check if the borrowed value has been marked as seen.
219     ///
220     /// # Examples
221     ///
222     /// ```
223     /// use ylong_runtime::sync::watch;
224     /// async fn io_func() {
225     ///     let (tx, mut rx) = watch::channel(1);
226     ///     let v1 = rx.borrow();
227     ///     assert_eq!(*v1, 1);
228     ///     assert!(!v1.is_notified());
229     ///     drop(v1);
230     ///
231     ///     let _ = tx.send(2);
232     ///     let v2 = rx.borrow_notify();
233     ///     assert_eq!(*v2, 2);
234     ///     assert!(v2.is_notified());
235     ///     drop(v2);
236     ///
237     ///     let v3 = rx.borrow_notify();
238     ///     assert_eq!(*v3, 2);
239     ///     assert!(!v3.is_notified());
240     /// }
241     /// ```
is_notified(&self) -> bool242     pub fn is_notified(&self) -> bool {
243         self.is_notified
244     }
245 }
246 
247 impl<T> Deref for ValueRef<'_, T> {
248     type Target = T;
249 
deref(&self) -> &Self::Target250     fn deref(&self) -> &Self::Target {
251         self.value.deref()
252     }
253 }
254 
/// The receiving half of a watch channel.
///
/// A [`Sender`] and [`Receiver`] handle pair is created by the [`channel`]
/// function. Each receiver records the version of the value it has last
/// marked as seen.
///
/// # Examples
///
/// ```
/// use ylong_runtime::sync::watch;
/// async fn io_func() {
///     let (tx, mut rx) = watch::channel(1);
///     ylong_runtime::spawn(async move {
///         let _ = rx.notified().await;
///         assert_eq!(*rx.borrow(), 2);
///     });
///
///     let _ = tx.send(2);
/// }
/// ```
#[derive(Debug)]
pub struct Receiver<T> {
    // State shared with the sender and all other receivers.
    channel: Arc<Channel<T>>,
    // Version of the value this receiver has last marked as seen.
    version: usize,
}
278 
279 impl<T> Receiver<T> {
280     /// Check if [`Receiver`] has been notified of a new value that has not been
281     /// marked as seen.
282     ///
283     /// An error would be returned if the channel is closed.
284     ///
285     /// # Examples
286     ///
287     /// ```
288     /// use ylong_runtime::sync::watch;
289     /// async fn io_func() {
290     ///     let (tx, mut rx) = watch::channel(1);
291     ///     assert_eq!(rx.is_notified(), Ok(false));
292     ///
293     ///     let _ = tx.send(2);
294     ///     assert_eq!(*rx.borrow(), 2);
295     ///     assert_eq!(rx.is_notified(), Ok(true));
296     ///
297     ///     assert_eq!(*rx.borrow_notify(), 2);
298     ///     assert_eq!(rx.is_notified(), Ok(false));
299     ///
300     ///     drop(tx);
301     ///     assert!(rx.is_notified().is_err());
302     /// }
303     /// ```
is_notified(&self) -> Result<bool, RecvError>304     pub fn is_notified(&self) -> Result<bool, RecvError> {
305         let (value_version, is_closed) = self.channel.state.load();
306         if is_closed {
307             return Err(RecvError);
308         }
309         Ok(self.version != value_version)
310     }
311 
try_notified(&mut self) -> Option<Result<(), RecvError>>312     pub(crate) fn try_notified(&mut self) -> Option<Result<(), RecvError>> {
313         let (value_version, is_closed) = self.channel.state.load();
314         if self.version != value_version {
315             self.version = value_version;
316             return Some(Ok(()));
317         }
318 
319         if is_closed {
320             return Some(Err(RecvError));
321         }
322 
323         None
324     }
325 
326     /// Polls to receive a notification from the associated [`Sender`].
327     ///
328     /// When the sender has not yet sent a new message and the message in
329     /// channel has seen, calling this method will return pending, and the
330     /// waker from the Context will receive a wakeup when the message
331     /// arrives or when the channel is closed. Multiple calls to this
332     /// method, only the waker from the last call will receive a wakeup.
333     ///
334     /// # Return value
335     /// * `Poll::Pending` if no new messages comes, but the channel is not
336     ///   closed.
337     /// * `Poll::Ready(Ok(T))` if receiving a new value or the value in channel
338     ///   has not yet seen.
339     /// * `Poll::Ready(Err(RecvError))` The sender has been dropped or the
340     ///   channel is closed.
341     ///
342     /// # Examples
343     ///
344     /// ```
345     /// use ylong_runtime::futures::poll_fn;
346     /// use ylong_runtime::sync::watch;
347     /// async fn io_func() {
348     ///     let (tx, mut rx) = watch::channel(1);
349     ///     let handle = ylong_runtime::spawn(async move {
350     ///         let _ = poll_fn(|cx| rx.poll_notified(cx)).await;
351     ///         assert_eq!(*rx.borrow(), 2);
352     ///     });
353     ///     assert!(tx.send(2).is_ok());
354     /// }
355     /// ```
poll_notified(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), RecvError>>356     pub fn poll_notified(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), RecvError>> {
357         match self.try_notified() {
358             Some(Ok(())) => return Ready(Ok(())),
359             Some(Err(e)) => return Ready(Err(e)),
360             None => {}
361         }
362 
363         self.channel.waker_list.insert(cx.waker().clone());
364 
365         match self.try_notified() {
366             Some(Ok(())) => Ready(Ok(())),
367             Some(Err(e)) => Ready(Err(e)),
368             None => Pending,
369         }
370     }
371 
372     /// Waits for a value change notification from the associated [`Sender`],
373     /// and marks the value as seen then.
374     ///
375     /// If the channel has a value that has not yet seen, this method will
376     /// return immediately and mark the value as seen. If the value in the
377     /// channel has already been marked as seen, this method will wait
378     /// asynchronously until the next new value arrives or the channel is
379     /// closed.
380     ///
381     /// # Return value
382     /// * `Ok(())` if receiving a new value or the value in channel has not yet
383     ///   seen.
384     /// * `Err(RecvError)` The sender has been dropped or the channel is closed.
385     ///
386     /// # Examples
387     ///
388     /// ```
389     /// use ylong_runtime::sync::watch;
390     /// async fn io_func() {
391     ///     let (tx, mut rx) = watch::channel(1);
392     ///     ylong_runtime::spawn(async move {
393     ///         let _ = rx.notified().await;
394     ///         assert_eq!(*rx.borrow(), 2);
395     ///     });
396     ///
397     ///     let _ = tx.send(2);
398     /// }
399     /// ```
notified(&mut self) -> Result<(), RecvError>400     pub async fn notified(&mut self) -> Result<(), RecvError> {
401         poll_fn(|cx| self.poll_notified(cx)).await
402     }
403 
404     /// Gets a reference to the inner value.
405     ///
406     /// This method doesn't mark the value as seen, which means call to
407     /// [`notified`] may return `Ok(())` immediately and call to [`is_notified`]
408     /// may return `Ok(true)` after calling this method.
409     ///
410     /// The reference returned from this method will hold a read lock on the
411     /// internal value, so holding this reference will block the sender from
412     /// sending data. When the watch channel runs in an environment that
413     /// allows !Send futures, you need to ensure that the reference is not held
414     /// across an. wait point to avoid deadlocks.
415     ///
416     /// The priority policy of RwLock is consistent with the `std::RwLock`.
417     ///
418     /// # Examples
419     ///
420     /// ```
421     /// use ylong_runtime::sync::watch;
422     /// async fn io_func() {
423     ///     let (tx, mut rx) = watch::channel(1);
424     ///     ylong_runtime::spawn(async move {
425     ///         let _ = rx.notified().await;
426     ///         assert_eq!(*rx.borrow(), 2);
427     ///     });
428     ///
429     ///     let _ = tx.send(2);
430     /// }
431     /// ```
432     ///
433     /// [`notified`]: Receiver::notified
434     /// [`is_notified`]: Receiver::is_notified
borrow(&self) -> ValueRef<'_, T>435     pub fn borrow(&self) -> ValueRef<'_, T> {
436         let (value_version, _) = self.channel.state.load();
437         let value = self.channel.value.read().unwrap();
438         let is_notified = self.version != value_version;
439         ValueRef::new(value, is_notified)
440     }
441 
442     /// Gets a reference to the inner value and marks the value as seen.
443     ///
444     /// This method marks the value as seen, which means call to [`notified`]
445     /// will wait until the next message comes and call to [`is_notified`] won't
446     /// return `Ok(true)` after calling this method.
447     ///
448     /// The reference returned from this method will hold a read lock on the
449     /// internal value, so holding this reference will block the sender from
450     /// sending data. When the watch channel runs in an environment that
451     /// allows !Send futures, you need to ensure that the reference is not held
452     /// across an. wait point to avoid deadlocks.
453     ///
454     /// The priority policy of RwLock is consistent with the `std::RwLock`.
455     ///
456     /// # Examples
457     ///
458     /// ```
459     /// use ylong_runtime::sync::watch;
460     /// async fn io_func() {
461     ///     let (tx, mut rx) = watch::channel(1);
462     ///     ylong_runtime::spawn(async move {
463     ///         let _ = rx.notified().await;
464     ///         assert_eq!(*rx.borrow_notify(), 2);
465     ///     });
466     ///
467     ///     let _ = tx.send(2);
468     /// }
469     /// ```
470     ///
471     /// [`notified`]: Receiver::notified
472     /// [`is_notified`]: Receiver::is_notified
borrow_notify(&mut self) -> ValueRef<'_, T>473     pub fn borrow_notify(&mut self) -> ValueRef<'_, T> {
474         let (value_version, _) = self.channel.state.load();
475         let value = self.channel.value.read().unwrap();
476         let is_notified = self.version != value_version;
477         self.version = value_version;
478         ValueRef::new(value, is_notified)
479     }
480 
481     /// Checks whether the receiver and another receiver belong to the same
482     /// channel.
483     ///
484     /// # Examples
485     ///
486     /// ```
487     /// use ylong_runtime::sync::watch;
488     /// let (tx, rx) = watch::channel(1);
489     /// let rx2 = rx.clone();
490     /// assert!(rx.is_same(&rx2));
491     /// ```
is_same(&self, other: &Self) -> bool492     pub fn is_same(&self, other: &Self) -> bool {
493         Arc::ptr_eq(&self.channel, &other.channel)
494     }
495 }
496 
497 impl<T> Clone for Receiver<T> {
clone(&self) -> Self498     fn clone(&self) -> Self {
499         self.channel.rx_cnt.fetch_add(1, Release);
500         Self {
501             channel: self.channel.clone(),
502             version: self.version,
503         }
504     }
505 }
506 
507 impl<T> Drop for Receiver<T> {
drop(&mut self)508     fn drop(&mut self) {
509         self.channel.rx_cnt.fetch_sub(1, Release);
510     }
511 }
512 
513 struct State(AtomicUsize);
514 
515 impl State {
new() -> State516     fn new() -> State {
517         State(AtomicUsize::new(0))
518     }
519 
version_update(&self)520     fn version_update(&self) {
521         self.0.fetch_add(1 << VERSION_SHIFT, Release);
522     }
523 
load(&self) -> (usize, bool)524     fn load(&self) -> (usize, bool) {
525         let state = self.0.load(Acquire);
526         let version = state >> VERSION_SHIFT;
527         let is_closed = state & CLOSED == CLOSED;
528         (version, is_closed)
529     }
530 
close(&self)531     fn close(&self) {
532         self.0.fetch_or(CLOSED, Release);
533     }
534 }
535 
536 struct Channel<T> {
537     value: RwLock<T>,
538     waker_list: WakerList,
539     state: State,
540     rx_cnt: AtomicUsize,
541 }
542 
543 impl<T> Channel<T> {
new(value: T) -> Channel<T>544     fn new(value: T) -> Channel<T> {
545         Channel {
546             value: RwLock::new(value),
547             waker_list: WakerList::new(),
548             state: State::new(),
549             rx_cnt: AtomicUsize::new(1),
550         }
551     }
552 
close(&self)553     fn close(&self) {
554         self.state.close();
555         self.waker_list.notify_all();
556     }
557 }
558 
559 impl<T: Debug> Debug for Channel<T> {
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result560     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
561         let (value_version, is_closed) = self.state.load();
562         f.debug_struct("Channel")
563             .field("value", &self.value)
564             .field("version", &value_version)
565             .field("is_closed", &is_closed)
566             .field("receiver_count", &self.rx_cnt.load(Acquire))
567             .finish()
568     }
569 }
570 
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering::Acquire;

    use crate::sync::error::RecvError;
    use crate::sync::watch;
    use crate::{block_on, spawn};

    /// UT test cases for `send()` and `try_notified()`.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Receiver tries receiving a change notification before the sender
    ///    sends one and gets `None`.
    /// 3. Receiver tries receiving a change notification after the sender sends
    ///    one and gets `Some(Ok(()))`.
    /// 4. Check that the borrowed value is the one that was sent.
    /// 5. Receiver tries receiving a change notification after the sender
    ///    drops and gets `Some(Err(RecvError))`.
    #[test]
    fn send_try_notified() {
        let (tx, mut rx) = watch::channel("hello");
        assert_eq!(rx.try_notified(), None);
        assert!(tx.send("world").is_ok());
        assert_eq!(rx.try_notified(), Some(Ok(())));
        assert_eq!(*rx.borrow(), "world");

        drop(tx);
        assert_eq!(rx.try_notified(), Some(Err(RecvError)));
    }

    /// UT test cases for `send()` and async `notified()`.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Sender sends a message and is then dropped.
    /// 3. Receiver awaits a notification inside a spawned task and gets `Ok`
    ///    for the value sent before the sender was dropped.
    /// 4. Check that the borrowed value is the one that was sent.
    /// 5. Receiver awaits a notification again and gets an error because the
    ///    sender has been dropped.
    #[test]
    fn send_notified_await() {
        let (tx, mut rx) = watch::channel("hello");
        assert!(tx.send("world").is_ok());
        drop(tx);
        let handle1 = spawn(async move {
            assert!(rx.notified().await.is_ok());
            assert_eq!(*rx.borrow(), "world");
            assert!(rx.notified().await.is_err());
        });
        let _ = block_on(handle1);
    }

    /// UT test cases for `send()` and `borrow_notify()`.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Check whether receiver contains a value which has not been seen
    ///    before and after `borrow()`.
    /// 3. Check whether receiver contains a value which has not been seen
    ///    before and after `borrow_notify()`.
    #[test]
    fn send_borrow_notify() {
        let (tx, mut rx) = watch::channel("hello");
        assert_eq!(rx.is_notified(), Ok(false));
        assert!(tx.send("world").is_ok());
        assert_eq!(rx.is_notified(), Ok(true));
        assert_eq!(*rx.borrow(), "world");
        assert_eq!(rx.is_notified(), Ok(true));
        assert_eq!(*rx.borrow_notify(), "world");
        assert_eq!(rx.is_notified(), Ok(false));
    }

    /// UT test cases for the count of the number of receivers.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair and check
    ///    that the receiver count is 1.
    /// 2. Create a receiver with `subscribe()` and check that the count is 2.
    /// 3. Clone a receiver and check that the count is 3.
    /// 4. Drop all receivers and check that the count is 0.
    #[test]
    fn receiver_count() {
        let (tx, rx) = watch::channel("hello");
        assert_eq!(tx.channel.rx_cnt.load(Acquire), 1);
        let rx2 = tx.subscribe();
        assert_eq!(tx.channel.rx_cnt.load(Acquire), 2);
        let rx3 = rx.clone();
        assert_eq!(tx.channel.rx_cnt.load(Acquire), 3);
        drop(rx);
        drop(rx2);
        drop(rx3);
        assert_eq!(tx.channel.rx_cnt.load(Acquire), 0);
    }
}
666