1use std::collections::{BTreeMap, HashMap};
2use std::hash::BuildHasher;
3use std::marker::PhantomData;
4
5use super::{
6 DeserializeHelper, Expected, SerdePathToken, TrackedError, Tracker, TrackerDeserializer, TrackerFor, report_de_error,
7 report_tracked_error, set_irrecoverable,
8};
9
10pub struct MapTracker<K: Eq, T, M> {
11 map: linear_map::LinearMap<K, T>,
12 _marker: PhantomData<M>,
13}
14
15impl<K: Eq, T, M> std::ops::Deref for MapTracker<K, T, M> {
16 type Target = linear_map::LinearMap<K, T>;
17
18 fn deref(&self) -> &Self::Target {
19 &self.map
20 }
21}
22
23impl<K: Eq, T, M> std::ops::DerefMut for MapTracker<K, T, M> {
24 fn deref_mut(&mut self) -> &mut Self::Target {
25 &mut self.map
26 }
27}
28
29impl<K: Eq + std::fmt::Debug, T: std::fmt::Debug, M> std::fmt::Debug for MapTracker<K, T, M> {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 let mut map = f.debug_map();
32 for (key, value) in self.iter() {
33 map.entry(key, value);
34 }
35 map.finish()
36 }
37}
38
39impl<K: Eq, T, M> Default for MapTracker<K, T, M> {
40 fn default() -> Self {
41 Self {
42 _marker: PhantomData,
43 map: linear_map::LinearMap::new(),
44 }
45 }
46}
47
/// Minimal map interface needed by the map deserialization code in this file.
///
/// Implemented below for `HashMap` and `BTreeMap`.
pub(crate) trait Map<K, V> {
    /// Returns a mutable reference to the value stored under `key`, if any.
    fn get_mut(&mut self, key: &K) -> Option<&mut V>;

    /// Stores `value` under `key`, returning an `Option<V>` as the underlying
    /// map's `insert` does.
    fn insert(&mut self, key: K, value: V) -> Option<V>;

    /// Hints that `additional` more entries are about to be inserted.
    /// Implementations may treat this as a no-op.
    fn reserve(&mut self, additional: usize);
}
53
54impl<K: Eq, T: Tracker, M: Default + Expected> Tracker for MapTracker<K, T, M> {
55 type Target = M;
56
57 fn allow_duplicates(&self) -> bool {
58 true
59 }
60}
61
62impl<K: std::hash::Hash + Eq + Expected, V: TrackerFor + Default + Expected, S: Default> TrackerFor for HashMap<K, V, S> {
63 type Tracker = MapTracker<K, V::Tracker, HashMap<K, <V::Tracker as Tracker>::Target, S>>;
64}
65
66impl<K: Ord + Expected, V: TrackerFor + Default + Expected> TrackerFor for BTreeMap<K, V> {
67 type Tracker = MapTracker<K, V::Tracker, BTreeMap<K, <V::Tracker as Tracker>::Target>>;
68}
69
70impl<K: std::hash::Hash + Eq, V: Default, S: BuildHasher> Map<K, V> for HashMap<K, V, S> {
71 fn get_mut(&mut self, key: &K) -> Option<&mut V> {
72 HashMap::get_mut(self, key)
73 }
74
75 fn insert(&mut self, key: K, value: V) -> Option<V> {
76 HashMap::insert(self, key, value)
77 }
78
79 fn reserve(&mut self, additional: usize) {
80 HashMap::reserve(self, additional)
81 }
82}
83
84impl<K: Ord, V: Default> Map<K, V> for BTreeMap<K, V> {
85 fn get_mut(&mut self, key: &K) -> Option<&mut V> {
86 BTreeMap::get_mut(self, key)
87 }
88
89 fn insert(&mut self, key: K, value: V) -> Option<V> {
90 BTreeMap::insert(self, key, value)
91 }
92
93 fn reserve(&mut self, _: usize) {}
94}
95
96impl<'de, K, T, M> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, MapTracker<K, T, M>>
97where
98 for<'a> DeserializeHelper<'a, T>: serde::de::DeserializeSeed<'de, Value = ()>,
99 T: Tracker + Default,
100 K: serde::de::Deserialize<'de> + std::cmp::Eq + Clone + std::fmt::Debug + Expected,
101 M: Map<K, T::Target>,
102 MapTracker<K, T, M>: Tracker<Target = M>,
103 T::Target: Default,
104{
105 type Value = ();
106
107 fn deserialize<D>(self, de: D) -> Result<Self::Value, D::Error>
108 where
109 D: serde::Deserializer<'de>,
110 {
111 de.deserialize_map(self)
112 }
113}
114
115impl<'de, K, T, M> serde::de::Visitor<'de> for DeserializeHelper<'_, MapTracker<K, T, M>>
116where
117 for<'a> DeserializeHelper<'a, T>: serde::de::DeserializeSeed<'de, Value = ()>,
118 T: Tracker + Default,
119 K: serde::de::Deserialize<'de> + std::cmp::Eq + Clone + std::fmt::Debug + Expected,
120 M: Map<K, T::Target>,
121 MapTracker<K, T, M>: Tracker<Target = M>,
122 T::Target: Default,
123{
124 type Value = ();
125
126 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
127 HashMap::<K, T::Target>::expecting(formatter)
128 }
129
130 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
131 where
132 A: serde::de::MapAccess<'de>,
133 {
134 if let Some(size) = map.size_hint() {
135 self.tracker.reserve(size);
136 self.value.reserve(size);
137 }
138
139 let mut new_value = T::Target::default();
140
141 while let Some(key) = map.next_key::<K>().inspect_err(|_| {
142 set_irrecoverable();
143 })? {
144 let _token = SerdePathToken::push_key(&key);
145 let entry = self.tracker.entry(key.clone());
146 if let linear_map::Entry::Occupied(entry) = &entry {
147 if !entry.get().allow_duplicates() {
148 report_tracked_error(TrackedError::duplicate_field())?;
149 map.next_value::<serde::de::IgnoredAny>().inspect_err(|_| {
150 set_irrecoverable();
151 })?;
152 continue;
153 }
154 }
155
156 let tracker = entry.or_insert_with(Default::default);
157 let value = self.value.get_mut(&key);
158 let used_new = value.is_none();
159 let value = value.unwrap_or(&mut new_value);
160 match map.next_value_seed(DeserializeHelper { value, tracker }) {
161 Ok(_) => {}
162 Err(error) => {
163 report_de_error(error)?;
164 continue;
165 }
166 }
167
168 drop(_token);
169
170 if used_new {
171 self.value.insert(key, std::mem::take(&mut new_value));
172 }
173 }
174
175 Ok(())
176 }
177}
178
179impl<'de, K, T, M> TrackerDeserializer<'de> for MapTracker<K, T, M>
180where
181 for<'a> DeserializeHelper<'a, T>: serde::de::DeserializeSeed<'de, Value = ()>,
182 T: Tracker + Default,
183 K: serde::de::Deserialize<'de> + std::cmp::Eq + Clone + std::fmt::Debug + Expected,
184 M: Map<K, T::Target>,
185 MapTracker<K, T, M>: Tracker<Target = M>,
186 T::Target: Default,
187{
188 fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
189 where
190 D: super::DeserializeContent<'de>,
191 {
192 deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
193 }
194}