gwr_engine/port/
mod.rs

1// Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2
3//! Port
4
5use std::cell::RefCell;
6use std::fmt;
7use std::pin::Pin;
8use std::rc::Rc;
9use std::task::{Context, Poll, Waker};
10
11use futures::Future;
12use futures::future::FusedFuture;
13use gwr_track::connect;
14use gwr_track::entity::{Entity, GetEntity};
15use gwr_track::tracker::aka::Aka;
16
17use crate::engine::Engine;
18use crate::port::monitor::Monitor;
19use crate::sim_error;
20use crate::time::clock::Clock;
21use crate::traits::SimObject;
22use crate::types::{SimError, SimResult};
23
24pub mod monitor;
25
26pub type PortStateResult<T> = Result<Rc<PortState<T>>, SimError>;
27pub type PortGetResult<T> = Result<PortGet<T>, SimError>;
28pub type PortStartGetResult<T> = Result<PortStartGet<T>, SimError>;
29pub type PortPutResult<T> = Result<PortPut<T>, SimError>;
30pub type PortTryPutResult<T> = Result<PortTryPut<T>, SimError>;
31
32pub struct PortState<T>
33where
34    T: SimObject,
35{
36    value: RefCell<Option<T>>,
37    waiting_get: RefCell<Option<Waker>>,
38    waiting_put: RefCell<Option<Waker>>,
39    pub in_port_entity: Rc<Entity>,
40    monitor: Option<Rc<Monitor>>,
41}
42
43impl<T> PortState<T>
44where
45    T: SimObject,
46{
47    fn new(
48        engine: &Engine,
49        clock: &Clock,
50        in_port_entity: Rc<Entity>,
51        window_size_ticks: Option<u64>,
52    ) -> Self {
53        let monitor = window_size_ticks.map(|window_size_ticks| {
54            Monitor::new_and_register(engine, &in_port_entity, clock, window_size_ticks)
55        });
56        Self {
57            value: RefCell::new(None),
58            waiting_get: RefCell::new(None),
59            waiting_put: RefCell::new(None),
60            in_port_entity,
61            monitor,
62        }
63    }
64}
65
66pub struct InPort<T>
67where
68    T: SimObject,
69{
70    entity: Rc<Entity>,
71    state: Rc<PortState<T>>,
72    connected: RefCell<bool>,
73}
74
75impl<T> fmt::Display for InPort<T>
76where
77    T: SimObject,
78{
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        self.entity.fmt(f)
81    }
82}
83
84impl<T> InPort<T>
85where
86    T: SimObject,
87{
88    #[must_use]
89    pub fn new(engine: &Engine, clock: &Clock, parent: &Rc<Entity>, name: &str) -> Self {
90        Self::new_with_renames(engine, clock, parent, name, None)
91    }
92
93    #[must_use]
94    pub fn new_with_renames(
95        engine: &Engine,
96        clock: &Clock,
97        parent: &Rc<Entity>,
98        name: &str,
99        aka: Option<&Aka>,
100    ) -> Self {
101        let entity = Rc::new(Entity::new_with_renames(parent, name, aka));
102        let monitor_window_size = entity.tracker.monitoring_window_size_for(entity.id);
103        Self {
104            entity: entity.clone(),
105            state: Rc::new(PortState::new(engine, clock, entity, monitor_window_size)),
106            connected: RefCell::new(false),
107        }
108    }
109
110    pub fn state(&self) -> PortStateResult<T> {
111        if *self.connected.borrow() {
112            return sim_error!("{self} already connected");
113        }
114
115        *self.connected.borrow_mut() = true;
116        Ok(self.state.clone())
117    }
118
119    #[must_use = "Futures do nothing unless you `.await` or otherwise use them"]
120    pub fn get(&self) -> PortGetResult<T> {
121        if !*self.connected.borrow() {
122            return sim_error!("{self} not connected");
123        }
124
125        Ok(PortGet {
126            state: self.state.clone(),
127            done: false,
128        })
129    }
130
131    /// Must be matched with a `finish_get` to allow the OutPort to continue.
132    #[must_use = "Futures do nothing unless you `.await` or otherwise use them"]
133    pub fn start_get(&self) -> PortStartGetResult<T> {
134        if !*self.connected.borrow() {
135            return sim_error!("{self} not connected");
136        }
137
138        Ok(PortStartGet {
139            state: self.state.clone(),
140            done: false,
141        })
142    }
143
144    /// Must be matched with a `start_get ` to consume the value.
145    pub fn finish_get(&self) {
146        if let Some(waker) = self.state.waiting_put.borrow_mut().take() {
147            waker.wake();
148        }
149    }
150}
151
152pub struct OutPort<T>
153where
154    T: SimObject,
155{
156    entity: Rc<Entity>,
157    state: Option<Rc<PortState<T>>>,
158}
159
160impl<T> GetEntity for OutPort<T>
161where
162    T: SimObject,
163{
164    fn entity(&self) -> &Rc<Entity> {
165        &self.entity
166    }
167}
168
169impl<T> fmt::Display for OutPort<T>
170where
171    T: SimObject,
172{
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        self.entity.fmt(f)
175    }
176}
177
178impl<T> OutPort<T>
179where
180    T: SimObject,
181{
182    #[must_use]
183    pub fn new(parent: &Rc<Entity>, name: &str) -> Self {
184        Self::new_with_renames(parent, name, None)
185    }
186
187    #[must_use]
188    pub fn new_with_renames(parent: &Rc<Entity>, name: &str, aka: Option<&Aka>) -> Self {
189        let entity = Rc::new(Entity::new_with_renames(parent, name, aka));
190        Self {
191            entity,
192            state: None,
193        }
194    }
195
196    pub fn connect(&mut self, port_state: PortStateResult<T>) -> SimResult {
197        let port_state = port_state?;
198
199        connect!(self.entity ; port_state.in_port_entity);
200        match self.state {
201            Some(_) => {
202                return sim_error!("{self} already connected");
203            }
204            None => {
205                self.state = Some(port_state);
206            }
207        }
208        Ok(())
209    }
210
211    #[must_use = "Futures do nothing unless you `.await` or otherwise use them"]
212    pub fn put(&self, value: T) -> PortPutResult<T> {
213        let state = match self.state.as_ref() {
214            Some(s) => s.clone(),
215            None => return sim_error!("{self} not connected"),
216        };
217        Ok(PortPut {
218            state,
219            value: RefCell::new(Some(value)),
220            done: RefCell::new(false),
221        })
222    }
223
224    #[must_use = "Futures do nothing unless you `.await` or otherwise use them"]
225    pub fn try_put(&self) -> PortTryPutResult<T> {
226        let state = match self.state.as_ref() {
227            Some(s) => s.clone(),
228            None => return sim_error!("{self} not connected"),
229        };
230        Ok(PortTryPut { state, done: false })
231    }
232}
233
234pub struct PortPut<T>
235where
236    T: SimObject,
237{
238    state: Rc<PortState<T>>,
239    value: RefCell<Option<T>>,
240    done: RefCell<bool>,
241}
242
243impl<T> Future for PortPut<T>
244where
245    T: SimObject,
246{
247    type Output = ();
248
249    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250        match self.value.take() {
251            Some(value) => {
252                // The state is designed to be shared between one put/get pair so it should
253                // not be possible for the value in the state to be set at this point.
254                assert!(self.state.value.borrow().is_none());
255
256                *self.state.value.borrow_mut() = Some(value);
257                if let Some(waker) = self.state.waiting_get.borrow_mut().take() {
258                    waker.wake();
259                }
260                *self.state.waiting_put.borrow_mut() = Some(cx.waker().clone());
261                Poll::Pending
262            }
263            None => {
264                // Value already sent, woken because it has been consumed
265                *self.done.borrow_mut() = true;
266                Poll::Ready(())
267            }
268        }
269    }
270}
271
272impl<T> FusedFuture for PortPut<T>
273where
274    T: SimObject,
275{
276    fn is_terminated(&self) -> bool {
277        *self.done.borrow()
278    }
279}
280
281pub struct PortTryPut<T>
282where
283    T: SimObject,
284{
285    state: Rc<PortState<T>>,
286    done: bool,
287}
288
289impl<T> Future for PortTryPut<T>
290where
291    T: SimObject,
292{
293    type Output = ();
294
295    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
296        if self.state.waiting_get.borrow().is_some() {
297            self.done = true;
298            Poll::Ready(())
299        } else {
300            *self.state.waiting_put.borrow_mut() = Some(cx.waker().clone());
301            Poll::Pending
302        }
303    }
304}
305
306impl<T> FusedFuture for PortTryPut<T>
307where
308    T: SimObject,
309{
310    fn is_terminated(&self) -> bool {
311        self.done
312    }
313}
314
315pub struct PortGet<T>
316where
317    T: SimObject,
318{
319    state: Rc<PortState<T>>,
320    done: bool,
321}
322
323impl<T> Future for PortGet<T>
324where
325    T: SimObject,
326{
327    type Output = T;
328
329    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
330        let value = self.state.value.borrow_mut().take();
331        if let Some(value) = value {
332            self.done = true;
333            self.state.waiting_get.borrow_mut().take();
334
335            // Track the object through the port monitor if there is one
336            if let Some(monitor) = self.state.monitor.as_ref() {
337                monitor.sample(&value);
338            }
339
340            if let Some(waker) = self.state.waiting_put.borrow_mut().take() {
341                waker.wake();
342            }
343            Poll::Ready(value)
344        } else {
345            if let Some(waker) = self.state.waiting_put.borrow_mut().take() {
346                waker.wake();
347            }
348
349            *self.state.waiting_get.borrow_mut() = Some(cx.waker().clone());
350            Poll::Pending
351        }
352    }
353}
354
355impl<T> FusedFuture for PortGet<T>
356where
357    T: SimObject,
358{
359    fn is_terminated(&self) -> bool {
360        self.done
361    }
362}
363
364pub struct PortStartGet<T>
365where
366    T: SimObject,
367{
368    state: Rc<PortState<T>>,
369    done: bool,
370}
371
372impl<T> Future for PortStartGet<T>
373where
374    T: SimObject,
375{
376    type Output = T;
377
378    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
379        let value = self.state.value.borrow_mut().take();
380        if let Some(value) = value {
381            self.done = true;
382            self.state.waiting_get.borrow_mut().take();
383
384            // Track the object through the port monitor if there is one
385            if let Some(monitor) = self.state.monitor.as_ref() {
386                monitor.sample(&value);
387            }
388
389            Poll::Ready(value)
390        } else {
391            *self.state.waiting_get.borrow_mut() = Some(cx.waker().clone());
392            Poll::Pending
393        }
394    }
395}
396
397impl<T> FusedFuture for PortStartGet<T>
398where
399    T: SimObject,
400{
401    fn is_terminated(&self) -> bool {
402        self.done
403    }
404}
405
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Wake, Waker};

    use futures::future::FusedFuture;
    use futures::task::noop_waker;
    use gwr_track::Tracker;
    use gwr_track::entity::Entity;
    use gwr_track::tracker::dev_null_tracker;

    use super::*;
    use crate::traits::TotalBytes;

    /// Engine and clock for one test, plus the tracker they report to.
    struct TestContext {
        // Just kept to ensure it isn't dropped
        _tracker: Tracker,
        engine: Engine,
        clock: Clock,
    }

    fn test_context() -> TestContext {
        let tracker = dev_null_tracker();
        let mut engine = Engine::new(&tracker);
        let clock = engine.default_clock();

        TestContext {
            _tracker: tracker,
            engine,
            clock,
        }
    }

    /// Shared state without a monitor.
    fn test_state<T: SimObject>() -> Rc<PortState<T>> {
        let context = test_context();
        let entity = Rc::new(Entity::new(context.engine.top(), "rx"));

        Rc::new(PortState::new(
            &context.engine,
            &context.clock,
            entity,
            None,
        ))
    }

    /// Shared state with a monitor over a one-tick window.
    fn monitored_test_state<T: SimObject>() -> Rc<PortState<T>> {
        let context = test_context();
        let entity = Rc::new(Entity::new(context.engine.top(), "rx"));

        Rc::new(PortState::new(
            &context.engine,
            &context.clock,
            entity,
            Some(1),
        ))
    }

    /// Waker that counts how many times it has been woken.
    struct WakeCounter {
        wakes_count: Arc<AtomicUsize>,
    }

    impl Wake for WakeCounter {
        fn wake(self: Arc<Self>) {
            self.wakes_count.fetch_add(1, Ordering::SeqCst);
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.wakes_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    fn counting_waker() -> (Arc<AtomicUsize>, Waker) {
        let wakes_count = Arc::new(AtomicUsize::new(0));
        let waker = Waker::from(Arc::new(WakeCounter {
            wakes_count: wakes_count.clone(),
        }));

        (wakes_count, waker)
    }

    #[test]
    fn wake_counter_counts_wake_and_wake_by_ref() {
        let (wakes_count, waker) = counting_waker();

        waker.wake_by_ref();
        assert_eq!(wakes_count.load(Ordering::SeqCst), 1);

        waker.wake();
        assert_eq!(wakes_count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn in_port_state_can_only_connect_once() {
        let context = test_context();
        let in_port =
            InPort::<i32>::new(&context.engine, &context.clock, context.engine.top(), "rx");

        assert!(in_port.state().is_ok());

        let err = in_port
            .state()
            .err()
            .expect("second state call should fail");
        assert!(format!("{err}").contains("already connected"));
    }

    #[test]
    fn out_port_connect_can_only_connect_once() {
        let context = test_context();
        let mut out_port = OutPort::<i32>::new(context.engine.top(), "tx");
        let first_in_port =
            InPort::new(&context.engine, &context.clock, context.engine.top(), "rx1");
        let second_in_port =
            InPort::new(&context.engine, &context.clock, context.engine.top(), "rx2");

        out_port.connect(first_in_port.state()).unwrap();

        let err = out_port.connect(second_in_port.state()).unwrap_err();
        assert!(format!("{err}").contains("already connected"));
    }

    #[test]
    fn out_port_entity_returns_port_entity() {
        let context = test_context();
        let out_port = OutPort::<i32>::new(context.engine.top(), "tx");

        assert!(Rc::ptr_eq(out_port.entity(), &out_port.entity));
    }

    #[test]
    fn start_get_requires_connection_and_finish_get_wakes_putter() {
        let context = test_context();
        let in_port =
            InPort::<i32>::new(&context.engine, &context.clock, context.engine.top(), "rx");

        assert!(in_port.start_get().is_err());
        assert!(in_port.state().is_ok());
        assert!(in_port.start_get().is_ok());

        let waker = noop_waker();
        *in_port.state.waiting_put.borrow_mut() = Some(waker);
        in_port.finish_get();

        assert!(in_port.state.waiting_put.borrow().is_none());
    }

    #[test]
    fn finish_get_without_waiting_putter_is_a_noop() {
        let context = test_context();
        let in_port =
            InPort::<i32>::new(&context.engine, &context.clock, context.engine.top(), "rx");

        in_port.finish_get();

        assert!(in_port.state.waiting_put.borrow().is_none());
    }

    #[test]
    fn port_put_reports_termination_after_second_poll() {
        let state = test_state::<i32>();
        let put = PortPut {
            state: state.clone(),
            value: RefCell::new(Some(123)),
            done: RefCell::new(false),
        };
        let mut put = Box::pin(put);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert_eq!(put.as_mut().poll(&mut cx), Poll::Pending);
        assert!(!put.is_terminated());
        assert_eq!(*state.value.borrow(), Some(123));
        assert!(state.waiting_put.borrow().is_some());

        assert_eq!(put.as_mut().poll(&mut cx), Poll::Ready(()));
        assert!(put.is_terminated());
    }

    #[test]
    fn port_try_put_waits_for_getter_then_completes() {
        let state = test_state::<i32>();
        let try_put = PortTryPut {
            state: state.clone(),
            done: false,
        };
        let mut try_put = Box::pin(try_put);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert_eq!(try_put.as_mut().poll(&mut cx), Poll::Pending);
        assert!(!try_put.is_terminated());
        assert!(state.waiting_put.borrow().is_some());

        *state.waiting_get.borrow_mut() = Some(noop_waker());

        assert_eq!(try_put.as_mut().poll(&mut cx), Poll::Ready(()));
        assert!(try_put.is_terminated());
    }

    #[test]
    fn connected_out_port_creates_try_put_future() {
        let context = test_context();
        let mut out_port = OutPort::<i32>::new(context.engine.top(), "tx");
        let in_port = InPort::new(&context.engine, &context.clock, context.engine.top(), "rx");

        out_port.connect(in_port.state()).unwrap();

        assert!(out_port.try_put().is_ok());
    }

    #[test]
    fn port_get_waits_then_returns_value_and_reports_termination() {
        let state = test_state::<i32>();
        let get = PortGet {
            state: state.clone(),
            done: false,
        };
        let mut get = Box::pin(get);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert_eq!(get.as_mut().poll(&mut cx), Poll::Pending);
        assert!(!get.is_terminated());
        assert!(state.waiting_get.borrow().is_some());

        *state.value.borrow_mut() = Some(456);
        *state.waiting_put.borrow_mut() = Some(noop_waker());

        assert_eq!(get.as_mut().poll(&mut cx), Poll::Ready(456));
        assert!(get.is_terminated());
        assert!(state.waiting_put.borrow().is_none());
    }

    #[test]
    fn port_get_pending_wakes_waiting_putter() {
        let state = test_state::<i32>();
        let get = PortGet {
            state: state.clone(),
            done: false,
        };
        let mut get = Box::pin(get);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        *state.waiting_put.borrow_mut() = Some(noop_waker());

        assert_eq!(get.as_mut().poll(&mut cx), Poll::Pending);

        assert!(state.waiting_put.borrow().is_none());
        assert!(state.waiting_get.borrow().is_some());
    }

    #[test]
    fn port_get_samples_monitored_values() {
        let state = monitored_test_state::<i32>();
        let monitor = state
            .monitor
            .as_ref()
            .expect("monitored state should create a monitor");
        let get = PortGet {
            state: state.clone(),
            done: false,
        };
        let mut get = Box::pin(get);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        *state.value.borrow_mut() = Some(456);

        assert_eq!(get.as_mut().poll(&mut cx), Poll::Ready(456));
        assert_eq!(monitor.bytes_in_window(), 456_i32.total_bytes());
    }

    #[test]
    fn port_start_get_waits_then_returns_value_without_finishing_put() {
        let state = test_state::<i32>();
        let start_get = PortStartGet {
            state: state.clone(),
            done: false,
        };
        let mut start_get = Box::pin(start_get);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert_eq!(start_get.as_mut().poll(&mut cx), Poll::Pending);
        assert!(!start_get.is_terminated());
        assert!(state.waiting_get.borrow().is_some());

        let (waiting_put_wakes, waiting_put_waker) = counting_waker();
        *state.waiting_put.borrow_mut() = Some(waiting_put_waker.clone());
        *state.value.borrow_mut() = Some(789);

        assert_eq!(start_get.as_mut().poll(&mut cx), Poll::Ready(789));
        assert!(start_get.is_terminated());
        assert!(state.waiting_get.borrow().is_none());
        assert_eq!(waiting_put_wakes.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn port_start_get_samples_monitored_values() {
        let state = monitored_test_state::<i32>();
        let monitor = state
            .monitor
            .as_ref()
            .expect("monitored state should create a monitor");
        let start_get = PortStartGet {
            state: state.clone(),
            done: false,
        };
        let mut start_get = Box::pin(start_get);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        *state.value.borrow_mut() = Some(789);

        assert_eq!(start_get.as_mut().poll(&mut cx), Poll::Ready(789));
        assert_eq!(monitor.bytes_in_window(), 789_i32.total_bytes());
    }

    #[test]
    fn unconnected_ports_refuse_transfers() {
        let context = test_context();
        let out_port = OutPort::<i32>::new(context.engine.top(), "tx");
        let in_port =
            InPort::<i32>::new(&context.engine, &context.clock, context.engine.top(), "rx");

        let put_err = out_port
            .put(1)
            .err()
            .expect("put on an unconnected port should fail");
        assert!(format!("{put_err}").contains("not connected"));
        assert!(out_port.try_put().is_err());
        assert!(in_port.get().is_err());
        assert!(in_port.start_get().is_err());
    }

    #[test]
    fn put_then_get_hands_value_over_and_releases_putter() {
        let context = test_context();
        let mut out_port = OutPort::<i32>::new(context.engine.top(), "tx");
        let in_port = InPort::new(&context.engine, &context.clock, context.engine.top(), "rx");
        out_port.connect(in_port.state()).unwrap();

        let (put_wakes, put_waker) = counting_waker();
        let mut put_cx = Context::from_waker(&put_waker);
        let get_waker = noop_waker();
        let mut get_cx = Context::from_waker(&get_waker);

        let mut put = Box::pin(out_port.put(42).unwrap());
        let mut get = Box::pin(in_port.get().unwrap());

        assert_eq!(put.as_mut().poll(&mut put_cx), Poll::Pending);
        assert_eq!(get.as_mut().poll(&mut get_cx), Poll::Ready(42));
        assert!(get.is_terminated());
        // Taking the value must release the parked putter.
        assert_eq!(put_wakes.load(Ordering::SeqCst), 1);

        assert_eq!(put.as_mut().poll(&mut put_cx), Poll::Ready(()));
        assert!(put.is_terminated());
    }

    #[test]
    fn start_get_keeps_putter_parked_until_finish_get() {
        let context = test_context();
        let mut out_port = OutPort::<i32>::new(context.engine.top(), "tx");
        let in_port = InPort::new(&context.engine, &context.clock, context.engine.top(), "rx");
        out_port.connect(in_port.state()).unwrap();

        let (put_wakes, put_waker) = counting_waker();
        let mut put_cx = Context::from_waker(&put_waker);
        let get_waker = noop_waker();
        let mut get_cx = Context::from_waker(&get_waker);

        let mut put = Box::pin(out_port.put(7).unwrap());
        let mut start_get = Box::pin(in_port.start_get().unwrap());

        assert_eq!(put.as_mut().poll(&mut put_cx), Poll::Pending);
        assert_eq!(start_get.as_mut().poll(&mut get_cx), Poll::Ready(7));
        assert_eq!(put_wakes.load(Ordering::SeqCst), 0);

        in_port.finish_get();
        assert_eq!(put_wakes.load(Ordering::SeqCst), 1);

        assert_eq!(put.as_mut().poll(&mut put_cx), Poll::Ready(()));
        assert!(put.is_terminated());
    }
}
722}