1use std::marker::PhantomData;
2
3use base64::Engine;
4use bytes::Bytes;
5
6use super::{DeserializeContent, DeserializeHelper, Expected, Tracker, TrackerDeserializer, TrackerFor};
7
/// Tracker for byte-buffer targets; `T` is the concrete buffer type being
/// filled (this file wires it up for `Vec<u8>` and `bytes::Bytes`).
pub struct BytesTracker<T>(PhantomData<T>);

// Written by hand instead of derived: `#[derive(Debug)]` would add a
// `T: Debug` bound, which this marker-only tracker does not need.
impl<T> std::fmt::Debug for BytesTracker<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let target_name = std::any::type_name::<T>();
        write!(f, "BytesTracker<{}>", target_name)
    }
}
15
16pub trait BytesLikeTracker: Tracker {
17 fn set_target(&mut self, target: &mut Self::Target, buf: impl bytes::Buf);
18
19 fn set_target_vec(&mut self, target: &mut Self::Target, data: Vec<u8>) {
20 self.set_target(target, data.as_slice());
21 }
22}
23
24impl BytesLikeTracker for BytesTracker<Bytes> {
25 fn set_target(&mut self, target: &mut Self::Target, mut buf: impl bytes::Buf) {
26 *target = buf.copy_to_bytes(buf.remaining());
27 }
28}
29impl BytesLikeTracker for BytesTracker<Vec<u8>> {
30 fn set_target(&mut self, target: &mut Self::Target, mut buf: impl bytes::Buf) {
31 target.clear();
32 target.reserve_exact(buf.remaining());
33 while buf.has_remaining() {
34 let chunk = buf.chunk();
35 target.extend_from_slice(chunk);
36 buf.advance(chunk.len());
37 }
38 }
39
40 fn set_target_vec(&mut self, target: &mut Self::Target, data: Vec<u8>) {
41 *target = data;
42 }
43}
44
45impl<T> Default for BytesTracker<T> {
46 fn default() -> Self {
47 BytesTracker(PhantomData)
48 }
49}
50
51impl<T: Expected> Tracker for BytesTracker<T> {
52 type Target = T;
53
54 fn allow_duplicates(&self) -> bool {
55 false
56 }
57}
58
59impl TrackerFor for Vec<u8> {
60 type Tracker = BytesTracker<Vec<u8>>;
61}
62
63impl Expected for Vec<u8> {
64 fn expecting(formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
65 write!(formatter, "bytes")
66 }
67}
68
69impl TrackerFor for bytes::Bytes {
70 type Tracker = BytesTracker<Self>;
71}
72
73impl Expected for bytes::Bytes {
74 fn expecting(formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
75 write!(formatter, "bytes")
76 }
77}
78
79impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, BytesTracker<T>>
80where
81 T: Expected,
82 BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
83{
84 type Value = ();
85
86 fn deserialize<D>(self, de: D) -> Result<Self::Value, D::Error>
87 where
88 D: serde::Deserializer<'de>,
89 {
90 de.deserialize_str(self)
91 }
92}
93
94impl<'de, T> serde::de::Visitor<'de> for DeserializeHelper<'_, BytesTracker<T>>
95where
96 T: Expected,
97 BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
98{
99 type Value = ();
100
101 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
102 T::expecting(formatter)
103 }
104
105 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
106 where
107 E: serde::de::Error,
108 {
109 let config = base64::engine::GeneralPurposeConfig::new()
110 .with_decode_allow_trailing_bits(true)
111 .with_encode_padding(true)
112 .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent);
113
114 let alphabet = if v.as_bytes().iter().any(|b| b == &b'-' || b == &b'_') {
115 &base64::alphabet::URL_SAFE
116 } else {
117 &base64::alphabet::STANDARD
118 };
119
120 let engine = base64::engine::GeneralPurpose::new(alphabet, config);
121 let bytes = engine.decode(v.as_bytes()).map_err(serde::de::Error::custom)?;
122 self.tracker.set_target_vec(self.value, bytes);
123 Ok(())
124 }
125}
126
127impl<'de, T> TrackerDeserializer<'de> for BytesTracker<T>
128where
129 T: Expected,
130 BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
131{
132 fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
133 where
134 D: DeserializeContent<'de>,
135 {
136 deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
137 }
138}