1use std::future::Future;
9use std::sync::Arc;
10
11use tokio::sync::oneshot;
12
13pub struct BatchResponse<Resp> {
15 send: oneshot::Sender<Resp>,
16}
17
18impl<Resp> BatchResponse<Resp> {
19 #[must_use]
21 pub fn new(send: oneshot::Sender<Resp>) -> Self {
22 Self { send }
23 }
24
25 #[inline(always)]
27 pub fn send(self, response: Resp) {
28 let _ = self.send.send(response);
29 }
30
31 #[inline(always)]
33 pub fn send_ok<O, E>(self, response: O)
34 where
35 Resp: From<Result<O, E>>,
36 {
37 self.send(Ok(response).into())
38 }
39
40 #[inline(always)]
42 pub fn send_err<O, E>(self, error: E)
43 where
44 Resp: From<Result<O, E>>,
45 {
46 self.send(Err(error).into())
47 }
48
49 #[inline(always)]
51 pub fn send_none<T>(self)
52 where
53 Resp: From<Option<T>>,
54 {
55 self.send(None.into())
56 }
57
58 #[inline(always)]
60 pub fn send_some<T>(self, value: T)
61 where
62 Resp: From<Option<T>>,
63 {
64 self.send(Some(value).into())
65 }
66}
67
68pub trait BatchExecutor {
70 type Request: Send + 'static;
72 type Response: Send + Sync + 'static;
74
75 fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) -> impl Future<Output = ()> + Send;
79}
80
/// Builder for a [`Batcher`] dispatching to an executor of type `E`.
///
/// The builder only stores plain settings, so it is `Copy`, `Clone`, `Debug`, `Send` and `Sync`
/// for every `E`. The impls are written by hand because `#[derive]` would add `E: Clone` /
/// `E: Copy` / `E: Debug` bounds that the stored data does not need.
#[must_use = "builders must be used to create a batcher"]
pub struct BatcherBuilder<E> {
    /// Maximum number of requests per batch.
    batch_size: usize,
    /// Maximum number of batches executing at the same time.
    concurrency: usize,
    /// How long a batch that has not filled up waits before it is dispatched.
    delay: std::time::Duration,
    // `fn() -> E` ties the builder to `E` (same covariance) without claiming to own an `E`,
    // so auto traits like `Send`/`Sync` do not depend on `E`.
    _marker: std::marker::PhantomData<fn() -> E>,
}

impl<E> Clone for BatcherBuilder<E> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<E> Copy for BatcherBuilder<E> {}

impl<E> std::fmt::Debug for BatcherBuilder<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BatcherBuilder")
            .field("batch_size", &self.batch_size)
            .field("concurrency", &self.concurrency)
            .field("delay", &self.delay)
            .finish_non_exhaustive()
    }
}
90
91impl<E> Default for BatcherBuilder<E> {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl<E> BatcherBuilder<E> {
98 pub const fn new() -> Self {
100 Self {
101 batch_size: 1000,
102 concurrency: 50,
103 delay: std::time::Duration::from_millis(5),
104 _marker: std::marker::PhantomData,
105 }
106 }
107
108 #[inline]
110 pub const fn batch_size(mut self, batch_size: usize) -> Self {
111 self.with_batch_size(batch_size);
112 self
113 }
114
115 #[inline]
117 pub const fn delay(mut self, delay: std::time::Duration) -> Self {
118 self.with_delay(delay);
119 self
120 }
121
122 #[inline]
124 pub const fn concurrency(mut self, concurrency: usize) -> Self {
125 self.with_concurrency(concurrency);
126 self
127 }
128
129 #[inline]
131 pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
132 self.concurrency = concurrency;
133 self
134 }
135
136 #[inline]
138 pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
139 self.batch_size = batch_size;
140 self
141 }
142
143 #[inline]
145 pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
146 self.delay = delay;
147 self
148 }
149
150 #[inline]
152 pub fn build(self, executor: E) -> Batcher<E>
153 where
154 E: BatchExecutor + Send + Sync + 'static,
155 {
156 Batcher::new(executor, self.batch_size, self.concurrency, self.delay)
157 }
158}
159
160#[must_use = "batchers must be used to execute batches"]
162pub struct Batcher<E>
163where
164 E: BatchExecutor + Send + Sync + 'static,
165{
166 _auto_spawn: tokio::task::JoinHandle<()>,
167 executor: Arc<E>,
168 semaphore: Arc<tokio::sync::Semaphore>,
169 current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
170 batch_size: usize,
171}
172
173struct Batch<E>
174where
175 E: BatchExecutor + Send + Sync + 'static,
176{
177 items: Vec<(E::Request, BatchResponse<E::Response>)>,
178 semaphore: Arc<tokio::sync::Semaphore>,
179 created_at: std::time::Instant,
180}
181
182impl<E> Batcher<E>
183where
184 E: BatchExecutor + Send + Sync + 'static,
185{
186 pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
188 let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
189 let current_batch = Arc::new(tokio::sync::Mutex::new(None));
190 let executor = Arc::new(executor);
191
192 let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));
193
194 Self {
195 executor,
196 _auto_spawn: join_handle,
197 semaphore,
198 current_batch,
199 batch_size: batch_size.max(1),
200 }
201 }
202
203 pub const fn builder() -> BatcherBuilder<E> {
205 BatcherBuilder::new()
206 }
207
208 pub async fn execute(&self, items: E::Request) -> Option<E::Response> {
210 self.execute_many(std::iter::once(items)).await.pop()?
211 }
212
213 pub async fn execute_many<I>(&self, items: I) -> Vec<Option<E::Response>>
215 where
216 I: IntoIterator<Item = E::Request>,
217 {
218 let mut responses = Vec::new();
219
220 {
221 let mut batch = self.current_batch.lock().await;
222
223 for item in items {
224 if batch.is_none() {
225 batch.replace(Batch::new(self.semaphore.clone()));
226 }
227
228 let batch_mut = batch.as_mut().unwrap();
229 let (tx, rx) = oneshot::channel();
230 batch_mut.items.push((item, BatchResponse::new(tx)));
231 responses.push(rx);
232
233 if batch_mut.items.len() >= self.batch_size {
234 tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
235 }
236 }
237 }
238
239 let mut results = Vec::with_capacity(responses.len());
240 for response in responses {
241 results.push(response.await.ok());
242 }
243
244 results
245 }
246}
247
248async fn batch_loop<E>(
249 executor: Arc<E>,
250 current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
251 delay: std::time::Duration,
252) where
253 E: BatchExecutor + Send + Sync + 'static,
254{
255 let mut delay_delta = delay;
256 loop {
257 tokio::time::sleep(delay_delta).await;
258
259 let mut batch = current_batch.lock().await;
260 let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
261 delay_delta = delay;
262 continue;
263 };
264
265 let remaining = delay.saturating_sub(created_at.elapsed());
266 if remaining == std::time::Duration::ZERO {
267 tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
268 delay_delta = delay;
269 } else {
270 delay_delta = remaining;
271 }
272 }
273}
274
275impl<E> Batch<E>
276where
277 E: BatchExecutor + Send + Sync + 'static,
278{
279 fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
280 Self {
281 created_at: std::time::Instant::now(),
282 items: Vec::new(),
283 semaphore,
284 }
285 }
286
287 async fn spawn(self, executor: Arc<E>) {
288 let _ticket = self.semaphore.acquire_owned().await;
289 executor.execute(self.items).await;
290 }
291}
292
// Tests for `Batcher`. Most of them assert wall-clock bounds on how long a call takes, and each
// test is skipped under valgrind.
// NOTE(review): the module is excluded on Windows and macOS, presumably because timer
// resolution there makes the tight timing bounds flaky — TODO confirm.
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(all(test, not(windows), not(target_os = "macos")))]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::AtomicUsize;

    use super::*;

    /// Test executor that answers each request from a fixed map after sleeping for `delay`.
    ///
    /// Requests whose key is not in `values` get no response, so their callers observe `None`.
    struct TestExecutor<K, V> {
        // Lookup table used to answer requests.
        values: HashMap<K, V>,
        // Simulated latency of every `execute` call.
        delay: std::time::Duration,
        // Incremented once per executed batch.
        requests: Arc<AtomicUsize>,
        // Largest batch this executor accepts; a larger batch fails an assertion.
        capacity: usize,
    }

    impl<K, V> BatchExecutor for TestExecutor<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
        V: Clone + Send + Sync + 'static,
    {
        type Request = K;
        type Response = V;

        async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
            tokio::time::sleep(self.delay).await;

            assert!(requests.len() <= self.capacity);

            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            for (request, response) in requests {
                if let Some(value) = self.values.get(&request) {
                    response.send(value.clone());
                }
                // Unknown keys: `response` is dropped unsent, so the caller gets `None`.
            }
        }
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        // Uses the builder's default delay. A lone request is dispatched by the delay; two
        // requests fill a batch and are dispatched immediately.
        let loader = Batcher::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.execute("a").await;
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.execute("b").await;
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        let start = std::time::Instant::now();
        let c = loader.execute("c").await;
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        // Exactly `batch_size` requests: one batch, one executor call.
        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "b"]).await;
        assert_eq!(ab, vec![Some(1), Some(2)]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        // A key the executor does not know is reported as `None`.
        let start = std::time::Instant::now();
        let unknown = loader.execute("unknown").await;
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder().batch_size(2).concurrency(10).build(fetcher);

        // Ten requests with batch size 2 form five full batches. They are dispatched immediately
        // and, with concurrency 10, may all execute at once.
        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        // Five full batches, but concurrency 1 makes them execute one after another (5 ms each).
        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        // Ten requests never fill a batch of 100, so the single batch waits for the 10 ms delay.
        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        // How the 1134 requests are split:
        // - 11 full batches are dispatched immediately; 10 execute at once and the 11th waits
        //   for a permit.
        // - The remaining 34 requests wait for the 10 ms delay.
        let start = std::time::Instant::now();
        let ab = loader.execute_many(0..1134).await;
        assert_eq!(ab, (0..1134).map(|i| Some(i * 2 + 5)).collect::<Vec<_>>());
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        // Let the background flush task run for a while before any request arrives.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        // Expected timeline (about 15 ms after `start`):
        // - The batch opens about 5 ms into the flush task's first 10 ms sleep.
        // - It is dispatched 10 ms after opening.
        // - The executor then takes another 5 ms.
        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a"]).await;
        assert_eq!(ab, vec![Some(1)]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn no_deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        // Duplicate keys stay separate requests. The 6 requests become a full batch of 4
        // (dispatched immediately) and a batch of 2 (dispatched by the delay).
        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "a", "b", "b", "c", "c"]).await;
        assert_eq!(ab, vec![Some(1), Some(1), Some(2), Some(2), Some(3), Some(3)]);
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn result() {
        let requests = Arc::new(AtomicUsize::new(0));

        // Executor whose response type is a `Result`, exercising `send_ok` / `send_err`.
        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Result<usize, ()>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_ok(value),
                        Err(_) => response.send_err(()),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        // Seven requests become a full batch of 4 and a batch of 3 dispatched by the delay.
        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Ok(1)),
                Some(Ok(1)),
                Some(Ok(2)),
                Some(Ok(2)),
                Some(Ok(3)),
                Some(Ok(3)),
                Some(Err(()))
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn option() {
        let requests = Arc::new(AtomicUsize::new(0));

        // Executor whose response type is an `Option`, exercising `send_some` / `send_none`.
        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Option<usize>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_some(value),
                        Err(_) => response.send_none(),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        // Seven requests become a full batch of 4 and a batch of 3 dispatched by the delay.
        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Some(1)),
                Some(Some(1)),
                Some(Some(2)),
                Some(Some(2)),
                Some(Some(3)),
                Some(Some(3)),
                Some(None)
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }
}