tinc/private/
bytes.rs

1use std::marker::PhantomData;
2
3use base64::Engine;
4use bytes::Bytes;
5
6use super::{DeserializeContent, DeserializeHelper, Expected, Tracker, TrackerDeserializer, TrackerFor};
7
/// Stateless tracker for byte-valued targets (used for `Vec<u8>` and `bytes::Bytes`).
///
/// It carries no data; `T` only records which target type it tracks.
pub struct BytesTracker<T>(PhantomData<T>);

impl<T> std::fmt::Debug for BytesTracker<T> {
    /// Renders as `BytesTracker<…>` with the tracked type's name filled in.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let target_name = std::any::type_name::<T>();
        f.write_str("BytesTracker<")?;
        f.write_str(target_name)?;
        f.write_str(">")
    }
}
15
16pub trait BytesLikeTracker: Tracker {
17    fn set_target(&mut self, target: &mut Self::Target, buf: impl bytes::Buf);
18
19    fn set_target_vec(&mut self, target: &mut Self::Target, data: Vec<u8>) {
20        self.set_target(target, data.as_slice());
21    }
22}
23
24impl BytesLikeTracker for BytesTracker<Bytes> {
25    fn set_target(&mut self, target: &mut Self::Target, mut buf: impl bytes::Buf) {
26        *target = buf.copy_to_bytes(buf.remaining());
27    }
28}
29impl BytesLikeTracker for BytesTracker<Vec<u8>> {
30    fn set_target(&mut self, target: &mut Self::Target, mut buf: impl bytes::Buf) {
31        target.clear();
32        target.reserve_exact(buf.remaining());
33        while buf.has_remaining() {
34            let chunk = buf.chunk();
35            target.extend_from_slice(chunk);
36            buf.advance(chunk.len());
37        }
38    }
39
40    fn set_target_vec(&mut self, target: &mut Self::Target, data: Vec<u8>) {
41        *target = data;
42    }
43}
44
45impl<T> Default for BytesTracker<T> {
46    fn default() -> Self {
47        BytesTracker(PhantomData)
48    }
49}
50
51impl<T: Expected> Tracker for BytesTracker<T> {
52    type Target = T;
53
54    fn allow_duplicates(&self) -> bool {
55        false
56    }
57}
58
59impl TrackerFor for Vec<u8> {
60    type Tracker = BytesTracker<Vec<u8>>;
61}
62
63impl Expected for Vec<u8> {
64    fn expecting(formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
65        write!(formatter, "bytes")
66    }
67}
68
69impl TrackerFor for bytes::Bytes {
70    type Tracker = BytesTracker<Self>;
71}
72
73impl Expected for bytes::Bytes {
74    fn expecting(formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
75        write!(formatter, "bytes")
76    }
77}
78
79impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, BytesTracker<T>>
80where
81    T: Expected,
82    BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
83{
84    type Value = ();
85
86    fn deserialize<D>(self, de: D) -> Result<Self::Value, D::Error>
87    where
88        D: serde::Deserializer<'de>,
89    {
90        de.deserialize_str(self)
91    }
92}
93
94impl<'de, T> serde::de::Visitor<'de> for DeserializeHelper<'_, BytesTracker<T>>
95where
96    T: Expected,
97    BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
98{
99    type Value = ();
100
101    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
102        T::expecting(formatter)
103    }
104
105    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
106    where
107        E: serde::de::Error,
108    {
109        let config = base64::engine::GeneralPurposeConfig::new()
110            .with_decode_allow_trailing_bits(true)
111            .with_encode_padding(true)
112            .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent);
113
114        let alphabet = if v.as_bytes().iter().any(|b| b == &b'-' || b == &b'_') {
115            &base64::alphabet::URL_SAFE
116        } else {
117            &base64::alphabet::STANDARD
118        };
119
120        let engine = base64::engine::GeneralPurpose::new(alphabet, config);
121        let bytes = engine.decode(v.as_bytes()).map_err(serde::de::Error::custom)?;
122        self.tracker.set_target_vec(self.value, bytes);
123        Ok(())
124    }
125}
126
127impl<'de, T> TrackerDeserializer<'de> for BytesTracker<T>
128where
129    T: Expected,
130    BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
131{
132    fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
133    where
134        D: DeserializeContent<'de>,
135    {
136        deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
137    }
138}