use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;

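/// A fetcher that resolves a set of keys to their values in a single batched
/// call. Implement this trait to tell the [`DataLoader`] how to load data.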
pub trait DataLoaderFetcher {
    /// The key type used to identify a value.
    type Key: Clone + Eq + std::hash::Hash + Send + Sync;
    /// The value type returned for a key.
    type Value: Clone + Send + Sync;

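    /// Loads the values for the given set of keys.
    ///
    /// Returning `None` signals that the fetch failed as a whole; individual
    /// keys missing from the returned map are treated as not found.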
    fn load(&self, keys: HashSet<Self::Key>) -> impl Future<Output = Option<HashMap<Self::Key, Self::Value>>> + Send;
}

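/// A builder used to configure and construct a [`DataLoader`].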
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a dataloader"]
pub struct DataLoaderBuilder<E> {
    batch_size: usize,
    concurrency: usize,
    delay: std::time::Duration,
    _phantom: std::marker::PhantomData<E>,
}

impl<E> Default for DataLoaderBuilder<E> {
    fn default() -> Self {
        Self::new()
    }
}

impl<E> DataLoaderBuilder<E> {
    /// Creates a builder with the default settings: a batch size of 1000, a
    /// concurrency of 50 and a delay of 5 milliseconds.
    pub const fn new() -> Self {
        Self {
            batch_size: 1000,
            concurrency: 50,
            delay: std::time::Duration::from_millis(5),
            _phantom: std::marker::PhantomData,
        }
    }

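    /// Sets the maximum number of keys per batch, consuming the builder.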
    #[inline]
    pub const fn batch_size(mut self, batch_size: usize) -> Self {
        self.with_batch_size(batch_size);
        self
    }

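    /// Sets how long a batch may wait for additional keys before it is executed, consuming the builder.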
    #[inline]
    pub const fn delay(mut self, delay: std::time::Duration) -> Self {
        self.with_delay(delay);
        self
    }

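    /// Sets the maximum number of batches that may be fetched at the same time, consuming the builder.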
    #[inline]
    pub const fn concurrency(mut self, concurrency: usize) -> Self {
        self.with_concurrency(concurrency);
        self
    }

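    /// Sets the maximum number of keys per batch by mutable reference.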
    #[inline]
    pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
        self.batch_size = batch_size;
        self
    }

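    /// Sets the batching delay by mutable reference.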
    #[inline]
    pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
        self.delay = delay;
        self
    }

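    /// Sets the concurrency limit by mutable reference.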
    #[inline]
    pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
        self.concurrency = concurrency;
        self
    }

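    /// Builds the [`DataLoader`] with the configured settings, spawning its background batching task.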
    #[inline]
    pub fn build(self, executor: E) -> DataLoader<E>
    where
        E: DataLoaderFetcher + Send + Sync + 'static,
    {
        DataLoader::new(executor, self.batch_size, self.concurrency, self.delay)
    }
}

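/// A dataloader that deduplicates and batches individual `load` calls into
/// grouped fetches, limiting both the batch size and the number of batches
/// running concurrently.
///
/// A minimal usage sketch (the `UserFetcher` type and `user_id` value here are
/// hypothetical and only illustrate the shape of the API):
///
/// ```ignore
/// let loader = DataLoader::builder()
///     .batch_size(100)
///     .delay(std::time::Duration::from_millis(2))
///     .build(UserFetcher::default());
///
/// // Concurrent `load` calls made within the delay window are coalesced
/// // into a single call to `UserFetcher::load`.
/// let user = loader.load(user_id).await.expect("fetch failed");
/// ```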
#[must_use = "dataloaders must be used to load data"]
pub struct DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    _auto_spawn: tokio::task::JoinHandle<()>,
    executor: Arc<E>,
    semaphore: Arc<tokio::sync::Semaphore>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    batch_size: usize,
}

impl<E> DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    /// Creates a new dataloader with the given batch size, concurrency limit
    /// and batching delay. A background task is spawned to flush pending
    /// batches once the delay has elapsed.
    pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
        let current_batch = Arc::new(tokio::sync::Mutex::new(None));
        let executor = Arc::new(executor);

        let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));

        Self {
            executor,
            _auto_spawn: join_handle,
            semaphore,
            current_batch,
            batch_size: batch_size.max(1),
        }
    }

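    /// Returns a [`DataLoaderBuilder`] for configuring a dataloader.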
    #[inline]
    pub const fn builder() -> DataLoaderBuilder<E> {
        DataLoaderBuilder::new()
    }

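    /// Loads a single key, batching it with other concurrent calls.
    ///
    /// Returns `Ok(None)` if the key was not found and `Err(())` if the
    /// underlying fetch failed.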
    pub async fn load(&self, items: E::Key) -> Result<Option<E::Value>, ()> {
        Ok(self.load_many(std::iter::once(items)).await?.into_values().next())
    }

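    /// Loads many keys at once, splitting them across batches as needed.
    ///
    /// Returns a map of the keys that were found; keys missing from the
    /// fetcher's response are simply absent from the result. Returns `Err(())`
    /// if any batch failed to load.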
    pub async fn load_many<I>(&self, items: I) -> Result<HashMap<E::Key, E::Value>, ()>
    where
        I: IntoIterator<Item = E::Key> + Send,
    {
        struct BatchWaiting<K, V> {
            keys: HashSet<K>,
            result: Arc<BatchResult<K, V>>,
        }

        let mut waiters = Vec::<BatchWaiting<E::Key, E::Value>>::new();

        let mut count = 0;

        {
            let mut new_batch = true;
            let mut batch = self.current_batch.lock().await;

            for item in items {
                // Start a fresh batch if there is no pending one.
                if batch.is_none() {
                    batch.replace(Batch::new(self.semaphore.clone()));
                    new_batch = true;
                }

                let batch_mut = batch.as_mut().unwrap();
                batch_mut.items.insert(item.clone());

                // Register a waiter for every batch this call contributes to.
                if new_batch {
                    new_batch = false;
                    waiters.push(BatchWaiting {
                        keys: HashSet::new(),
                        result: batch_mut.result.clone(),
                    });
                }

                let waiting = waiters.last_mut().unwrap();
                waiting.keys.insert(item);

                count += 1;

                // Flush the batch immediately once it is full.
                if batch_mut.items.len() >= self.batch_size {
                    tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
                }
            }
        }

        // Wait for every batch we joined and collect the values for our keys.
        let mut results = HashMap::with_capacity(count);
        for waiting in waiters {
            let result = waiting.result.wait().await?;
            results.extend(waiting.keys.into_iter().filter_map(|key| {
                let value = result.get(&key)?.clone();
                Some((key, value))
            }));
        }

        Ok(results)
    }
}

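/// Background task that flushes the pending batch once its delay has elapsed,
/// sleeping only for the time the current batch still has to wait.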
async fn batch_loop<E>(
    executor: Arc<E>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    delay: std::time::Duration,
) where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    let mut delay_delta = delay;
    loop {
        tokio::time::sleep(delay_delta).await;

        let mut batch = current_batch.lock().await;
        let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
            // No pending batch: go back to sleep for the full delay.
            delay_delta = delay;
            continue;
        };

        let remaining = delay.saturating_sub(created_at.elapsed());
        if remaining == std::time::Duration::ZERO {
            // The batch has waited long enough: spawn it and reset the timer.
            tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
            delay_delta = delay;
        } else {
            // The batch was created mid-sleep: only sleep for the remainder.
            delay_delta = remaining;
        }
    }
}

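/// Shared result slot for a batch. The cancellation token is used purely as a
/// "done" signal: it is cancelled (via a drop guard) once the batch finishes,
/// waking every waiter.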
struct BatchResult<K, V> {
    values: tokio::sync::OnceCell<Option<HashMap<K, V>>>,
    token: tokio_util::sync::CancellationToken,
}

impl<K, V> BatchResult<K, V> {
    fn new() -> Self {
        Self {
            values: tokio::sync::OnceCell::new(),
            token: tokio_util::sync::CancellationToken::new(),
        }
    }

    async fn wait(&self) -> Result<&HashMap<K, V>, ()> {
        // Wait until the batch signals completion by cancelling the token.
        if !self.token.is_cancelled() {
            self.token.cancelled().await;
        }

        // An unset cell or a `None` payload both mean the fetch failed.
        self.values.get().ok_or(())?.as_ref().ok_or(())
    }
}

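/// A pending batch of keys waiting to be executed, together with the shared
/// result slot its waiters are parked on.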
struct Batch<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    items: HashSet<E::Key>,
    result: Arc<BatchResult<E::Key, E::Value>>,
    semaphore: Arc<tokio::sync::Semaphore>,
    created_at: std::time::Instant,
}

impl<E> Batch<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
        Self {
            items: HashSet::new(),
            result: Arc::new(BatchResult::new()),
            semaphore,
            created_at: std::time::Instant::now(),
        }
    }

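    /// Executes the batch: waits for a concurrency permit, runs the fetcher
    /// and publishes the result. The drop guard cancels the token on exit
    /// (even on panic), so waiters are always woken.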
    async fn spawn(self, executor: Arc<E>) {
        let _drop_guard = self.result.token.clone().drop_guard();
        let _ticket = self.semaphore.acquire_owned().await.unwrap();
        let result = executor.load(self.items).await;

        #[cfg_attr(all(coverage_nightly, test), coverage(off))]
        fn unknown_error<E>(_: E) -> ! {
            unreachable!(
                "batch result already set, this is a bug, please report it at https://github.com/scufflecloud/scuffle/issues"
            )
        }

        self.result.values.set(result).map_err(unknown_error).unwrap();
    }
}

#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(all(test, not(windows), not(target_os = "macos")))]
mod tests {
    use std::sync::atomic::AtomicUsize;

    use super::*;

    struct TestFetcher<K, V> {
        values: HashMap<K, V>,
        delay: std::time::Duration,
        requests: Arc<AtomicUsize>,
        capacity: usize,
    }

    impl<K, V> DataLoaderFetcher for TestFetcher<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync,
        V: Clone + Send + Sync,
    {
        type Key = K;
        type Value = V;

        async fn load(&self, keys: HashSet<Self::Key>) -> Option<HashMap<Self::Key, Self::Value>> {
            assert!(keys.len() <= self.capacity);
            tokio::time::sleep(self.delay).await;
            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            Some(
                keys.into_iter()
                    .filter_map(|k| {
                        let value = self.values.get(&k)?.clone();
                        Some((k, value))
                    })
                    .collect(),
            )
        }
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.load("a").await.unwrap();
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.load("b").await.unwrap();
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);

        let start = std::time::Instant::now();
        let c = loader.load("c").await.unwrap();
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "b"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        let start = std::time::Instant::now();
        let unknown = loader.load("unknown").await.unwrap();
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(10).build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(0..1134).await.unwrap();
        assert_eq!(ab, HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = DataLoader::builder()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "a", "b", "b", "c", "c"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn already_batch() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(10).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let (a, b) = tokio::join!(loader.load("a"), loader.load("b"));
        assert_eq!(a, Ok(Some(1)));
        assert_eq!(b, Ok(Some(2)));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }
}