gwr_engine/
executor.rs

1// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2
3use std::cell::RefCell;
4use std::future::Future;
5use std::mem;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
9
10use gwr_track::entity::Entity;
11
12use crate::time::clock::Clock;
13use crate::time::simtime::SimTime;
14use crate::types::SimResult;
15
16fn no_op(_: *const ()) {}
17
18unsafe fn drop_task(data: *const ()) {
19    unsafe {
20        drop(Rc::from_raw(data as *const Task));
21    }
22}
23
24static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw_waker, wake_task, no_op, drop_task);
25
26fn task_raw_waker(task: Rc<Task>) -> RawWaker {
27    let ptr = Rc::into_raw(task) as *const ();
28    RawWaker::new(ptr, &VTABLE)
29}
30
31fn waker_for_task(task: Rc<Task>) -> Waker {
32    unsafe { Waker::from_raw(task_raw_waker(task)) }
33}
34
35unsafe fn clone_raw_waker(data: *const ()) -> RawWaker {
36    unsafe {
37        // Tasks are always wrapped in a reference counter to allow them to be shared
38        // read-only. The input `data` pointer is borrowed — we must not decrement its
39        // refcount, so we mem::forget the reconstructed Rc rather than letting it drop.
40        let rc_task = Rc::from_raw(data as *const Task);
41        let clone = rc_task.clone();
42        mem::forget(rc_task);
43        let ptr = Rc::into_raw(clone) as *const ();
44        RawWaker::new(ptr, &VTABLE)
45    }
46}
47
48unsafe fn wake_task(data: *const ()) {
49    unsafe {
50        // Tasks are always wrapped in a reference counter to allow them to be shared
51        // read-only.
52        let rc_task = Rc::from_raw(data as *const Task);
53        let cloned = rc_task.clone();
54        rc_task.executor_state.new_tasks.borrow_mut().push(cloned);
55    }
56}
57
58struct Task {
59    future: RefCell<Pin<Box<dyn Future<Output = SimResult>>>>,
60    executor_state: Rc<ExecutorState>,
61}
62
63impl Task {
64    pub fn new(
65        future: impl Future<Output = SimResult> + 'static,
66        executor_state: Rc<ExecutorState>,
67    ) -> Task {
68        Task {
69            future: RefCell::new(Box::pin(future)),
70            executor_state,
71        }
72    }
73
74    fn poll(&self, context: &mut Context) -> Poll<SimResult> {
75        self.future.borrow_mut().as_mut().poll(context)
76    }
77}
78
79struct ExecutorState {
80    task_queue: RefCell<Vec<Rc<Task>>>,
81    new_tasks: RefCell<Vec<Rc<Task>>>,
82    time: RefCell<SimTime>,
83}
84
85impl ExecutorState {
86    pub fn new(top: &Rc<Entity>) -> Self {
87        Self {
88            task_queue: RefCell::new(Vec::new()),
89            new_tasks: RefCell::new(Vec::new()),
90            time: RefCell::new(SimTime::new(top)),
91        }
92    }
93}
94
95/// Single-threaded executor
96///
97/// This is a thin-wrapper (using [`Rc`]) around the real executor, so that this
98/// struct can be cloned and passed around.
99///
100/// See the [module documentation] for more details.
101///
102/// [module documentation]: index.html
103#[derive(Clone)]
104pub struct Executor {
105    state: Rc<ExecutorState>,
106}
107
108impl Executor {
109    pub fn run(&self, finished: &Rc<RefCell<bool>>) -> SimResult {
110        loop {
111            self.step(finished)?;
112            if *finished.borrow() {
113                break;
114            }
115
116            if self.state.new_tasks.borrow().is_empty() {
117                if self.state.time.borrow().can_exit() {
118                    break;
119                }
120
121                if let Some(wakers) = self.state.time.borrow_mut().advance_time() {
122                    // No events left, advance time
123                    for task_waker in wakers.into_iter() {
124                        task_waker.waker.wake();
125                    }
126                } else {
127                    break;
128                }
129            }
130        }
131        Ok(())
132    }
133
134    pub fn step(&self, finished: &Rc<RefCell<bool>>) -> SimResult {
135        // Append new tasks created since the last step into the task queue
136        let mut task_queue = self.state.task_queue.borrow_mut();
137        task_queue.append(&mut self.state.new_tasks.borrow_mut());
138
139        // Loop over all tasks, polling them. If a task is not ready, add it to the
140        // pending tasks.
141        for task in task_queue.drain(..) {
142            if *finished.borrow() {
143                break;
144            }
145
146            // Dummy waker and context (not used as we poll all tasks)
147            let waker = waker_for_task(task.clone());
148            let mut context = Context::from_waker(&waker);
149
150            match task.poll(&mut context) {
151                Poll::Ready(Err(e)) => {
152                    // Error - return early
153                    return Err(e);
154                }
155                Poll::Ready(Ok(())) => {
156                    // Otherwise, drop task as it is complete
157                }
158                Poll::Pending => {
159                    // Task will have parked itself waiting somewhere
160                }
161            }
162        }
163        Ok(())
164    }
165
166    #[must_use]
167    pub fn get_clock(&self, freq_mhz: f64) -> Clock {
168        self.state.time.borrow_mut().get_clock(freq_mhz)
169    }
170
171    #[must_use]
172    pub fn time_now_ns(&self) -> f64 {
173        self.state.time.borrow().time_now_ns()
174    }
175}
176
177/// `Spawner` spawns new futures into the executor.
178#[derive(Clone)]
179pub struct Spawner {
180    state: Rc<ExecutorState>,
181}
182
183impl Spawner {
184    pub fn spawn(&self, future: impl Future<Output = SimResult> + 'static) {
185        self.state
186            .new_tasks
187            .borrow_mut()
188            .push(Rc::new(Task::new(future, self.state.clone())));
189    }
190}
191
192#[must_use]
193pub fn new_executor_and_spawner(top: &Rc<Entity>) -> (Executor, Spawner) {
194    let state = Rc::new(ExecutorState::new(top));
195    (
196        Executor {
197            state: state.clone(),
198        },
199        Spawner { state },
200    )
201}