gwr_engine/
executor.rs

1// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2
3use std::cell::{Cell, RefCell};
4use std::future::Future;
5use std::mem;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
9
10use gwr_track::entity::Entity;
11use rand::SeedableRng;
12use rand::rngs::StdRng;
13use rand::seq::SliceRandom;
14
15use crate::time::clock::Clock;
16use crate::time::simtime::SimTime;
17use crate::types::SimResult;
18
19fn no_op(_: *const ()) {}
20
21unsafe fn drop_task(data: *const ()) {
22    unsafe {
23        drop(Rc::from_raw(data as *const Task));
24    }
25}
26
27static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw_waker, wake_task, no_op, drop_task);
28
29fn task_raw_waker(task: Rc<Task>) -> RawWaker {
30    let ptr = Rc::into_raw(task) as *const ();
31    RawWaker::new(ptr, &VTABLE)
32}
33
34fn waker_for_task(task: Rc<Task>) -> Waker {
35    unsafe { Waker::from_raw(task_raw_waker(task)) }
36}
37
38unsafe fn clone_raw_waker(data: *const ()) -> RawWaker {
39    unsafe {
40        // Tasks are always wrapped in a reference counter to allow them to be shared
41        // read-only. The input `data` pointer is borrowed — we must not decrement its
42        // refcount, so we mem::forget the reconstructed Rc rather than letting it drop.
43        let rc_task = Rc::from_raw(data as *const Task);
44        let clone = rc_task.clone();
45        mem::forget(rc_task);
46        let ptr = Rc::into_raw(clone) as *const ();
47        RawWaker::new(ptr, &VTABLE)
48    }
49}
50
51unsafe fn wake_task(data: *const ()) {
52    unsafe {
53        // Tasks are always wrapped in a reference counter to allow them to be shared
54        // read-only.
55        let rc_task = Rc::from_raw(data as *const Task);
56        let cloned = rc_task.clone();
57        rc_task.executor_state.new_tasks.borrow_mut().push(cloned);
58    }
59}
60
61struct Task {
62    future: RefCell<Option<Pin<Box<dyn Future<Output = SimResult>>>>>,
63    executor_state: Rc<ExecutorState>,
64}
65
66impl Task {
67    pub fn new(
68        future: impl Future<Output = SimResult> + 'static,
69        executor_state: Rc<ExecutorState>,
70    ) -> Task {
71        Task {
72            future: RefCell::new(Some(Box::pin(future))),
73            executor_state,
74        }
75    }
76
77    fn poll(&self, context: &mut Context) -> Poll<SimResult> {
78        let mut future_slot = self.future.borrow_mut();
79        let Some(future) = future_slot.as_mut() else {
80            return Poll::Ready(Ok(()));
81        };
82
83        let poll_result = future.as_mut().poll(context);
84        if poll_result.is_ready() {
85            future_slot.take();
86        }
87
88        poll_result
89    }
90}
91
92struct ExecutorState {
93    task_queue: RefCell<Vec<Rc<Task>>>,
94    new_tasks: RefCell<Vec<Rc<Task>>>,
95    time: RefCell<SimTime>,
96    randomize_task_order: Cell<bool>,
97    task_order_rng: RefCell<StdRng>,
98}
99
100impl ExecutorState {
101    pub fn new(top: &Rc<Entity>) -> Self {
102        Self {
103            task_queue: RefCell::new(Vec::new()),
104            new_tasks: RefCell::new(Vec::new()),
105            time: RefCell::new(SimTime::new(top)),
106            randomize_task_order: Cell::new(false),
107            task_order_rng: RefCell::new(StdRng::seed_from_u64(rand::random())),
108        }
109    }
110}
111
112/// Single-threaded executor
113///
114/// This is a thin-wrapper (using [`Rc`]) around the real executor, so that this
115/// struct can be cloned and passed around.
116///
117/// See the [module documentation] for more details.
118///
119/// [module documentation]: index.html
120#[derive(Clone)]
121pub struct Executor {
122    state: Rc<ExecutorState>,
123}
124
125impl Executor {
126    pub fn run(&self, finished: &Rc<RefCell<bool>>) -> SimResult {
127        loop {
128            self.step(finished)?;
129            if *finished.borrow() {
130                break;
131            }
132
133            if self.state.new_tasks.borrow().is_empty() {
134                if self.state.time.borrow().can_exit() {
135                    break;
136                }
137
138                if let Some(wakers) = self.state.time.borrow_mut().advance_time() {
139                    // No events left, advance time
140                    for task_waker in wakers.into_iter() {
141                        task_waker.waker.wake();
142                    }
143                } else {
144                    break;
145                }
146            }
147        }
148        Ok(())
149    }
150
151    pub fn step(&self, finished: &Rc<RefCell<bool>>) -> SimResult {
152        // Append new tasks created since the last step into the task queue
153        let mut task_queue = self.state.task_queue.borrow_mut();
154        task_queue.append(&mut self.state.new_tasks.borrow_mut());
155        if self.state.randomize_task_order.get() {
156            task_queue.shuffle(&mut *self.state.task_order_rng.borrow_mut());
157        }
158
159        // Loop over all tasks, polling them. If a task is not ready, add it to the
160        // pending tasks.
161        for task in task_queue.drain(..) {
162            if *finished.borrow() {
163                break;
164            }
165
166            // Dummy waker and context (not used as we poll all tasks)
167            let waker = waker_for_task(task.clone());
168            let mut context = Context::from_waker(&waker);
169
170            match task.poll(&mut context) {
171                Poll::Ready(Err(e)) => {
172                    // Error - return early
173                    return Err(e);
174                }
175                Poll::Ready(Ok(())) => {
176                    // Otherwise, drop task as it is complete
177                }
178                Poll::Pending => {
179                    // Task will have parked itself waiting somewhere
180                }
181            }
182        }
183        Ok(())
184    }
185
186    #[must_use]
187    pub fn get_clock(&self, freq_mhz: f64) -> Clock {
188        self.state.time.borrow_mut().get_clock(freq_mhz)
189    }
190
191    #[must_use]
192    pub fn time_now_ns(&self) -> f64 {
193        self.state.time.borrow().time_now_ns()
194    }
195
196    pub fn set_randomize_task_order(&self, randomize: bool) {
197        self.state.randomize_task_order.set(randomize);
198    }
199
200    pub fn set_task_order_seed(&self, seed: u64) {
201        *self.state.task_order_rng.borrow_mut() = StdRng::seed_from_u64(seed);
202    }
203}
204
205/// `Spawner` spawns new futures into the executor.
206#[derive(Clone)]
207pub struct Spawner {
208    state: Rc<ExecutorState>,
209}
210
211impl Spawner {
212    pub fn spawn(&self, future: impl Future<Output = SimResult> + 'static) {
213        self.state
214            .new_tasks
215            .borrow_mut()
216            .push(Rc::new(Task::new(future, self.state.clone())));
217    }
218}
219
220#[must_use]
221pub fn new_executor_and_spawner(top: &Rc<Entity>) -> (Executor, Spawner) {
222    let state = Rc::new(ExecutorState::new(top));
223    (
224        Executor {
225            state: state.clone(),
226        },
227        Spawner { state },
228    )
229}