tinc/private/
map.rs

1use std::collections::{BTreeMap, HashMap};
2use std::hash::BuildHasher;
3use std::marker::PhantomData;
4
5use super::{
6    DeserializeHelper, Expected, SerdePathToken, TrackedError, Tracker, TrackerDeserializer, TrackerFor, report_de_error,
7    report_tracked_error, set_irrecoverable,
8};
9
10pub struct MapTracker<K: Eq, T, M> {
11    map: linear_map::LinearMap<K, T>,
12    _marker: PhantomData<M>,
13}
14
15impl<K: Eq, T, M> std::ops::Deref for MapTracker<K, T, M> {
16    type Target = linear_map::LinearMap<K, T>;
17
18    fn deref(&self) -> &Self::Target {
19        &self.map
20    }
21}
22
23impl<K: Eq, T, M> std::ops::DerefMut for MapTracker<K, T, M> {
24    fn deref_mut(&mut self) -> &mut Self::Target {
25        &mut self.map
26    }
27}
28
29impl<K: Eq + std::fmt::Debug, T: std::fmt::Debug, M> std::fmt::Debug for MapTracker<K, T, M> {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        let mut map = f.debug_map();
32        for (key, value) in self.iter() {
33            map.entry(key, value);
34        }
35        map.finish()
36    }
37}
38
39impl<K: Eq, T, M> Default for MapTracker<K, T, M> {
40    fn default() -> Self {
41        Self {
42            _marker: PhantomData,
43            map: linear_map::LinearMap::new(),
44        }
45    }
46}
47
/// Minimal mutable-map interface the map visitor uses to merge deserialized entries into
/// the target collection (implemented for `HashMap` and `BTreeMap` in this file).
pub(crate) trait Map<K, V> {
    /// Mutable reference to the value stored under `key`, or `None` if absent.
    fn get_mut(&mut self, key: &K) -> Option<&mut V>;

    /// Stores `value` under `key`, returning the value previously stored there, if any.
    fn insert(&mut self, key: K, value: V) -> Option<V>;

    /// Capacity hint for `additional` upcoming entries; implementations may ignore it.
    fn reserve(&mut self, additional: usize);
}
53
54impl<K: Eq, T: Tracker, M: Default + Expected> Tracker for MapTracker<K, T, M> {
55    type Target = M;
56
57    fn allow_duplicates(&self) -> bool {
58        true
59    }
60}
61
62impl<K: std::hash::Hash + Eq + Expected, V: TrackerFor + Default + Expected, S: Default> TrackerFor for HashMap<K, V, S> {
63    type Tracker = MapTracker<K, V::Tracker, HashMap<K, <V::Tracker as Tracker>::Target, S>>;
64}
65
66impl<K: Ord + Expected, V: TrackerFor + Default + Expected> TrackerFor for BTreeMap<K, V> {
67    type Tracker = MapTracker<K, V::Tracker, BTreeMap<K, <V::Tracker as Tracker>::Target>>;
68}
69
70impl<K: std::hash::Hash + Eq, V: Default, S: BuildHasher> Map<K, V> for HashMap<K, V, S> {
71    fn get_mut(&mut self, key: &K) -> Option<&mut V> {
72        HashMap::get_mut(self, key)
73    }
74
75    fn insert(&mut self, key: K, value: V) -> Option<V> {
76        HashMap::insert(self, key, value)
77    }
78
79    fn reserve(&mut self, additional: usize) {
80        HashMap::reserve(self, additional)
81    }
82}
83
84impl<K: Ord, V: Default> Map<K, V> for BTreeMap<K, V> {
85    fn get_mut(&mut self, key: &K) -> Option<&mut V> {
86        BTreeMap::get_mut(self, key)
87    }
88
89    fn insert(&mut self, key: K, value: V) -> Option<V> {
90        BTreeMap::insert(self, key, value)
91    }
92
93    fn reserve(&mut self, _: usize) {}
94}
95
96impl<'de, K, T, M> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, MapTracker<K, T, M>>
97where
98    for<'a> DeserializeHelper<'a, T>: serde::de::DeserializeSeed<'de, Value = ()>,
99    T: Tracker + Default,
100    K: serde::de::Deserialize<'de> + std::cmp::Eq + Clone + std::fmt::Debug + Expected,
101    M: Map<K, T::Target>,
102    MapTracker<K, T, M>: Tracker<Target = M>,
103    T::Target: Default,
104{
105    type Value = ();
106
107    fn deserialize<D>(self, de: D) -> Result<Self::Value, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        de.deserialize_map(self)
112    }
113}
114
115impl<'de, K, T, M> serde::de::Visitor<'de> for DeserializeHelper<'_, MapTracker<K, T, M>>
116where
117    for<'a> DeserializeHelper<'a, T>: serde::de::DeserializeSeed<'de, Value = ()>,
118    T: Tracker + Default,
119    K: serde::de::Deserialize<'de> + std::cmp::Eq + Clone + std::fmt::Debug + Expected,
120    M: Map<K, T::Target>,
121    MapTracker<K, T, M>: Tracker<Target = M>,
122    T::Target: Default,
123{
124    type Value = ();
125
126    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
127        HashMap::<K, T::Target>::expecting(formatter)
128    }
129
130    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
131    where
132        A: serde::de::MapAccess<'de>,
133    {
134        if let Some(size) = map.size_hint() {
135            self.tracker.reserve(size);
136            self.value.reserve(size);
137        }
138
139        let mut new_value = T::Target::default();
140
141        while let Some(key) = map.next_key::<K>().inspect_err(|_| {
142            set_irrecoverable();
143        })? {
144            let _token = SerdePathToken::push_key(&key);
145            let entry = self.tracker.entry(key.clone());
146            if let linear_map::Entry::Occupied(entry) = &entry {
147                if !entry.get().allow_duplicates() {
148                    report_tracked_error(TrackedError::duplicate_field())?;
149                    map.next_value::<serde::de::IgnoredAny>().inspect_err(|_| {
150                        set_irrecoverable();
151                    })?;
152                    continue;
153                }
154            }
155
156            let tracker = entry.or_insert_with(Default::default);
157            let value = self.value.get_mut(&key);
158            let used_new = value.is_none();
159            let value = value.unwrap_or(&mut new_value);
160            match map.next_value_seed(DeserializeHelper { value, tracker }) {
161                Ok(_) => {}
162                Err(error) => {
163                    report_de_error(error)?;
164                    continue;
165                }
166            }
167
168            drop(_token);
169
170            if used_new {
171                self.value.insert(key, std::mem::take(&mut new_value));
172            }
173        }
174
175        Ok(())
176    }
177}
178
179impl<'de, K, T, M> TrackerDeserializer<'de> for MapTracker<K, T, M>
180where
181    for<'a> DeserializeHelper<'a, T>: serde::de::DeserializeSeed<'de, Value = ()>,
182    T: Tracker + Default,
183    K: serde::de::Deserialize<'de> + std::cmp::Eq + Clone + std::fmt::Debug + Expected,
184    M: Map<K, T::Target>,
185    MapTracker<K, T, M>: Tracker<Target = M>,
186    T::Target: Default,
187{
188    fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
189    where
190        D: super::DeserializeContent<'de>,
191    {
192        deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
193    }
194}