gwr_components/
test_helpers.rs

1// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2
3use std::cmp::min;
4use std::collections::HashMap;
5use std::rc::Rc;
6
7use gwr_engine::engine::Engine;
8use gwr_engine::port::InPort;
9use gwr_track::entity::Entity;
10
11use crate::arbiter::Arbiter;
12use crate::arbiter::policy::{Priority, PriorityRoundRobin};
13use crate::flow_controls::limiter::Limiter;
14use crate::source::Source;
15use crate::store::Store;
16use crate::{connect_port, option_box_repeat, rc_limiter};
17
18#[derive(Clone)]
19pub struct ArbiterInputData {
20    pub val: usize,
21    pub count: usize,
22    pub weight: usize,
23    pub priority: Priority,
24}
25
26pub fn check_round_robin(inputs: &[ArbiterInputData], data: &[usize]) {
27    let total_count: usize = inputs.iter().map(|i| i.count).sum();
28    assert_eq!(data.len(), total_count);
29
30    let mut inputs = inputs.to_vec();
31    let mut offset = 0;
32    loop {
33        // Determine the count for each input value in the next window. Note that this
34        // copes with inputs producing the same value and inputs not producing
35        // their full weight in the window.
36        let mut expected_window_counts: HashMap<usize, usize> = HashMap::new();
37        let mut window_length = 0;
38        let max_priority = inputs
39            .iter()
40            .map(|i| {
41                if i.count > 0 {
42                    i.priority
43                } else {
44                    Priority::default()
45                }
46            })
47            .max()
48            .unwrap();
49        for input in &mut inputs {
50            let value_count = min(input.count, input.weight);
51            if input.priority == max_priority && value_count > 0 {
52                expected_window_counts
53                    .entry(input.val)
54                    .and_modify(|e| *e += value_count)
55                    .or_insert(value_count);
56
57                window_length += value_count;
58                input.count -= value_count;
59            }
60        }
61        if window_length == 0 {
62            return;
63        }
64
65        let mut window_counts = HashMap::new();
66        for value in data.iter().skip(offset).take(window_length) {
67            window_counts
68                .entry(*value)
69                .and_modify(|e| *e += 1)
70                .or_insert(1);
71        }
72        assert_eq!(window_counts, expected_window_counts);
73
74        offset += window_length;
75    }
76}
77
78pub fn priority_policy_test_core(engine: &mut Engine, inputs: &[ArbiterInputData]) {
79    let clock = engine.default_clock();
80    let spawner = engine.spawner();
81    let num_inputs = inputs.len();
82    let total_count = inputs.iter().map(|e| e.count).sum();
83    let mut policy = PriorityRoundRobin::new(num_inputs);
84    for (i, input) in inputs.iter().enumerate() {
85        policy = policy.set_priority(i, input.priority);
86    }
87
88    let arbiter = Arbiter::new_and_register(
89        engine,
90        engine.top(),
91        "arb",
92        spawner.clone(),
93        num_inputs,
94        Box::new(policy),
95    )
96    .unwrap();
97    let mut sources = Vec::new();
98    for (i, input) in inputs.iter().enumerate() {
99        sources.push(
100            Source::new_and_register(
101                engine,
102                engine.top(),
103                &("source_".to_owned() + &i.to_string()),
104                option_box_repeat!(input.val; input.count),
105            )
106            .unwrap(),
107        );
108    }
109
110    let write_limiter = rc_limiter!(clock, 1);
111    let store_limiter =
112        Limiter::new_and_register(engine, engine.top(), "limit_wr", write_limiter).unwrap();
113    let store =
114        Store::new_and_register(engine, engine.top(), "store", spawner, total_count).unwrap();
115    connect_port!(store_limiter, tx => store, rx).unwrap();
116
117    for (i, source) in sources.iter_mut().enumerate() {
118        connect_port!(source, tx => arbiter, rx, i).unwrap();
119    }
120    connect_port!(arbiter, tx => store_limiter, rx).unwrap();
121
122    let port = InPort::new(&Rc::new(Entity::new(engine.top(), "port")), "test_rx");
123    store.connect_port_tx(port.state()).unwrap();
124
125    let check_inputs = inputs.to_owned();
126    engine.spawn(async move {
127        let mut store_get = vec![0; total_count];
128        for i in &mut store_get {
129            *i = port.get()?.await;
130        }
131
132        check_round_robin(&check_inputs, &store_get);
133        Ok(())
134    });
135}