scuffle_batching/
batch.rs

1//! Types related to the batcher and batch executor.
2//!
3//! The [`Batcher`] is used to batch requests to a [`BatchExecutor`].
4//!
5//! This shouldn't be used to fetch data.
6//! Please refer to the [`dataloader`](crate::dataloader) module for that.
7
8use std::future::Future;
9use std::sync::Arc;
10
11use tokio::sync::oneshot;
12
/// A response to a batch request
///
/// Wraps the sending half of a oneshot channel; the requester awaits the
/// receiving half. Dropping this without calling `send` makes the requester
/// observe `None`.
pub struct BatchResponse<Resp> {
    /// Oneshot channel used to deliver the response to the awaiting caller.
    send: oneshot::Sender<Resp>,
}
17
18impl<Resp> BatchResponse<Resp> {
19    /// Create a new batch response
20    #[must_use]
21    pub fn new(send: oneshot::Sender<Resp>) -> Self {
22        Self { send }
23    }
24
25    /// Send a response back to the requester
26    #[inline(always)]
27    pub fn send(self, response: Resp) {
28        let _ = self.send.send(response);
29    }
30
31    /// Send a successful response back to the requester
32    #[inline(always)]
33    pub fn send_ok<O, E>(self, response: O)
34    where
35        Resp: From<Result<O, E>>,
36    {
37        self.send(Ok(response).into())
38    }
39
40    /// Send an error response back to the requestor
41    #[inline(always)]
42    pub fn send_err<O, E>(self, error: E)
43    where
44        Resp: From<Result<O, E>>,
45    {
46        self.send(Err(error).into())
47    }
48
49    /// Send a `None` response back to the requestor
50    #[inline(always)]
51    pub fn send_none<T>(self)
52    where
53        Resp: From<Option<T>>,
54    {
55        self.send(None.into())
56    }
57
58    /// Send a value response back to the requestor
59    #[inline(always)]
60    pub fn send_some<T>(self, value: T)
61    where
62        Resp: From<Option<T>>,
63    {
64        self.send(Some(value).into())
65    }
66}
67
/// A trait for executing batches
pub trait BatchExecutor {
    /// The incoming request type
    type Request: Send + 'static;
    /// The outgoing response type
    type Response: Send + Sync + 'static;

    /// Execute a batch of requests
    /// You must call `send` on the `BatchResponse` to send the response back to
    /// the client
    // NOTE(review): if a `BatchResponse` is dropped without `send`, the
    // corresponding caller appears to receive `None` (the oneshot receiver
    // errors and `execute_many` maps that to `None`) — confirm against callers.
    fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) -> impl Future<Output = ()> + Send;
}
80
/// A builder for a [`Batcher`]
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a batcher"]
pub struct BatcherBuilder<E> {
    /// Maximum number of requests per batch (default 1000, see `new`).
    batch_size: usize,
    /// Maximum number of batches executing at once (default 50, see `new`).
    concurrency: usize,
    /// How long a partially-filled batch may wait before being executed
    /// (default 5ms, see `new`).
    delay: std::time::Duration,
    /// Ties the builder to the executor type without storing one.
    _marker: std::marker::PhantomData<E>,
}
90
impl<E> Default for BatcherBuilder<E> {
    /// Equivalent to [`BatcherBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
96
97impl<E> BatcherBuilder<E> {
98    /// Create a new builder
99    pub const fn new() -> Self {
100        Self {
101            batch_size: 1000,
102            concurrency: 50,
103            delay: std::time::Duration::from_millis(5),
104            _marker: std::marker::PhantomData,
105        }
106    }
107
108    /// Set the batch size
109    #[inline]
110    pub const fn batch_size(mut self, batch_size: usize) -> Self {
111        self.with_batch_size(batch_size);
112        self
113    }
114
115    /// Set the delay
116    #[inline]
117    pub const fn delay(mut self, delay: std::time::Duration) -> Self {
118        self.with_delay(delay);
119        self
120    }
121
122    /// Set the concurrency to 1
123    #[inline]
124    pub const fn concurrency(mut self, concurrency: usize) -> Self {
125        self.with_concurrency(concurrency);
126        self
127    }
128
129    /// Set the concurrency
130    #[inline]
131    pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
132        self.concurrency = concurrency;
133        self
134    }
135
136    /// Set the batch size
137    #[inline]
138    pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
139        self.batch_size = batch_size;
140        self
141    }
142
143    /// Set the delay
144    #[inline]
145    pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
146        self.delay = delay;
147        self
148    }
149
150    /// Build the batcher
151    #[inline]
152    pub fn build(self, executor: E) -> Batcher<E>
153    where
154        E: BatchExecutor + Send + Sync + 'static,
155    {
156        Batcher::new(executor, self.batch_size, self.concurrency, self.delay)
157    }
158}
159
/// A batcher used to batch requests to a [`BatchExecutor`]
#[must_use = "batchers must be used to execute batches"]
pub struct Batcher<E>
where
    E: BatchExecutor + Send + Sync + 'static,
{
    /// Handle to the background delay-flush task.
    // NOTE(review): dropping a JoinHandle detaches the task rather than
    // aborting it, and `batch_loop` holds Arc clones of the state — so the
    // loop appears to keep running after the Batcher is dropped. Confirm
    // whether an abort-on-drop was intended.
    _auto_spawn: tokio::task::JoinHandle<()>,
    /// The user-provided executor, shared with spawned batch tasks.
    executor: Arc<E>,
    /// Limits how many batches may execute concurrently.
    semaphore: Arc<tokio::sync::Semaphore>,
    /// The batch currently being filled, if any.
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    /// Maximum number of requests per batch (clamped to at least 1 in `new`).
    batch_size: usize,
}
172
/// An in-flight batch of requests waiting to be executed.
struct Batch<E>
where
    E: BatchExecutor + Send + Sync + 'static,
{
    /// Requests paired with the handles used to answer them.
    items: Vec<(E::Request, BatchResponse<E::Response>)>,
    /// Shared concurrency limiter; a permit is acquired when the batch runs.
    semaphore: Arc<tokio::sync::Semaphore>,
    /// When this batch was created; `batch_loop` uses it to decide when the
    /// configured delay has elapsed.
    created_at: std::time::Instant,
}
181
impl<E> Batcher<E>
where
    E: BatchExecutor + Send + Sync + 'static,
{
    /// Create a new batcher
    ///
    /// `batch_size` and `concurrency` are clamped to a minimum of 1.
    /// `delay` is how long a partially-filled batch may wait before the
    /// background loop flushes it.
    ///
    /// Spawns a tokio task, so this must be called within a tokio runtime.
    pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
        let current_batch = Arc::new(tokio::sync::Mutex::new(None));
        let executor = Arc::new(executor);

        // Background task that flushes partially-filled batches once `delay`
        // has elapsed since the batch was created.
        let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));

        Self {
            executor,
            _auto_spawn: join_handle,
            semaphore,
            current_batch,
            batch_size: batch_size.max(1),
        }
    }

    /// Create a builder for a [`Batcher`]
    pub const fn builder() -> BatcherBuilder<E> {
        BatcherBuilder::new()
    }

    /// Execute a single request
    ///
    /// Returns `None` if the executor never sent a response for this request.
    pub async fn execute(&self, items: E::Request) -> Option<E::Response> {
        // `pop` yields the single Option<Response>; `?` flattens the
        // "empty vec" case (which cannot happen for a one-item iterator).
        self.execute_many(std::iter::once(items)).await.pop()?
    }

    /// Execute many requests
    ///
    /// Results are returned in the same order as the input items. Each entry
    /// is `None` if the executor dropped that request's [`BatchResponse`]
    /// without sending.
    pub async fn execute_many<I>(&self, items: I) -> Vec<Option<E::Response>>
    where
        I: IntoIterator<Item = E::Request>,
    {
        let mut responses = Vec::new();

        {
            // Hold the lock for the whole loop so all items are appended to
            // consecutive batches without interleaving from other callers.
            let mut batch = self.current_batch.lock().await;

            for item in items {
                // Lazily start a new batch; also re-created after a full
                // batch is taken below.
                if batch.is_none() {
                    batch.replace(Batch::new(self.semaphore.clone()));
                }

                let batch_mut = batch.as_mut().unwrap();
                let (tx, rx) = oneshot::channel();
                batch_mut.items.push((item, BatchResponse::new(tx)));
                responses.push(rx);

                // Flush immediately once the batch is full; the spawned task
                // waits on the concurrency semaphore before executing.
                if batch_mut.items.len() >= self.batch_size {
                    tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
                }
            }
        }

        // Await the responses in input order; a dropped sender becomes `None`.
        let mut results = Vec::with_capacity(responses.len());
        for response in responses {
            results.push(response.await.ok());
        }

        results
    }
}
247
/// Background loop that flushes a partially-filled batch once `delay` has
/// elapsed since the batch was created.
///
/// Sleeps an adaptive amount: the full `delay` when there is no pending
/// batch, otherwise just the remainder of the current batch's window.
// NOTE(review): this loop never terminates — it holds Arc clones of the
// state, so it appears to keep running even after the owning Batcher is
// dropped. Confirm whether that task leak is acceptable.
async fn batch_loop<E>(
    executor: Arc<E>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    delay: std::time::Duration,
) where
    E: BatchExecutor + Send + Sync + 'static,
{
    let mut delay_delta = delay;
    loop {
        tokio::time::sleep(delay_delta).await;

        let mut batch = current_batch.lock().await;
        // No pending batch: go back to sleeping the full delay.
        let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
            delay_delta = delay;
            continue;
        };

        let remaining = delay.saturating_sub(created_at.elapsed());
        if remaining == std::time::Duration::ZERO {
            // The batch's window has expired: hand it off for execution.
            tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
            delay_delta = delay;
        } else {
            // Woke up early (batch was created mid-sleep): sleep only the
            // remainder of its window.
            delay_delta = remaining;
        }
    }
}
274
275impl<E> Batch<E>
276where
277    E: BatchExecutor + Send + Sync + 'static,
278{
279    fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
280        Self {
281            created_at: std::time::Instant::now(),
282            items: Vec::new(),
283            semaphore,
284        }
285    }
286
287    async fn spawn(self, executor: Arc<E>) {
288        let _ticket = self.semaphore.acquire_owned().await;
289        executor.execute(self.items).await;
290    }
291}
292
/// TODO: Windows is disabled because I suspect Windows doesn't measure time
/// precisely enough to test the time-sensitive tests.
/// We should fix this and re-enable the tests.
/// Similar issue with macOS, but macOS is disabled because it is too slow
/// in CI and the tests fail due to timeouts.
/// CLOUD-74
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(all(test, not(windows), not(target_os = "macos")))]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::AtomicUsize;

    use super::*;

    /// Test executor that resolves requests against a fixed map after a
    /// configurable delay, counting how many batches it executes.
    struct TestExecutor<K, V> {
        // Lookup table of request -> response.
        values: HashMap<K, V>,
        // Simulated per-batch execution time.
        delay: std::time::Duration,
        // Number of batches executed so far (shared with the test body).
        requests: Arc<AtomicUsize>,
        // Maximum batch size this executor expects to receive.
        capacity: usize,
    }

    impl<K, V> BatchExecutor for TestExecutor<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
        V: Clone + Send + Sync + 'static,
    {
        type Request = K;
        type Response = V;

        async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
            tokio::time::sleep(self.delay).await;

            // The batcher must never hand us more than the configured batch size.
            assert!(requests.len() <= self.capacity);

            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            for (request, response) in requests {
                // Unknown keys get no response, which surfaces as `None` to
                // the caller.
                if let Some(value) = self.values.get(&request) {
                    response.send(value.clone());
                }
            }
        }
    }

    // Single and multi-item execution, including an unknown key -> None.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.execute("a").await;
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.execute("b").await;
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        let start = std::time::Instant::now();
        let c = loader.execute("c").await;
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "b"]).await;
        assert_eq!(ab, vec![Some(1), Some(2)]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        let start = std::time::Instant::now();
        let unknown = loader.execute("unknown").await;
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // 10 items with batch_size 2 -> 5 batches, which run concurrently.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder().batch_size(2).concurrency(10).build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // With concurrency 1 the 5 batches run serially; total time reflects that.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // All 10 items fit in one batch, so it is flushed by the delay loop.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    // Many items; order is preserved and batch count matches batch_size.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(0..1134).await;
        assert_eq!(ab, (0..1134).map(|i| Some(i * 2 + 5)).collect::<Vec<_>>());
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    // The background loop idles correctly before any request arrives.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    // A lone item in a not-full batch waits out the full delay window.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a"]).await;
        assert_eq!(ab, vec![Some(1)]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Unlike the dataloader, the batcher does NOT deduplicate identical keys.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn no_deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "a", "b", "b", "c", "c"]).await;
        assert_eq!(ab, vec![Some(1), Some(1), Some(2), Some(2), Some(3), Some(3)]);
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // send_ok / send_err with a Result response type.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn result() {
        let requests = Arc::new(AtomicUsize::new(0));

        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Result<usize, ()>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_ok(value),
                        Err(_) => response.send_err(()),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Ok(1)),
                Some(Ok(1)),
                Some(Ok(2)),
                Some(Ok(2)),
                Some(Ok(3)),
                Some(Ok(3)),
                Some(Err(()))
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // send_some / send_none with an Option response type.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn option() {
        let requests = Arc::new(AtomicUsize::new(0));

        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Option<usize>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_some(value),
                        Err(_) => response.send_none(),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Some(1)),
                Some(Some(1)),
                Some(Some(2)),
                Some(Some(2)),
                Some(Some(3)),
                Some(Some(3)),
                Some(None)
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }
}
659}