1use std::marker::PhantomData;
2
3use serde::de::{Unexpected, VariantAccess};
4
5use super::{
6 DeserializeContent, DeserializeHelper, Expected, IdentifiedValue, Identifier, IdentifierDeserializer, IdentifierFor,
7 MapAccessValueDeserializer, SerdeDeserializer, SerdePathToken, TrackedError, Tracker, TrackerDeserializer, TrackerFor,
8 TrackerWrapper, report_de_error, report_tracked_error, set_irrecoverable,
9};
10
/// Type-level helper that names the inner type of a wrapper.
///
/// The only implementation here maps `Option<Inner>` to `Inner`.
pub trait OneOfHelper {
    /// The wrapped type.
    type Target;
}

impl<Inner> OneOfHelper for Option<Inner> {
    type Target = Inner;
}
18
19pub trait TaggedOneOfIdentifier: Identifier {
20 const TAG: Self;
21 const CONTENT: Self;
22}
23
24pub trait TrackerDeserializeIdentifier<'de>: Tracker
25where
26 Self::Target: IdentifierFor,
27{
28 fn deserialize<D>(
29 &mut self,
30 value: &mut Self::Target,
31 identifier: <Self::Target as IdentifierFor>::Identifier,
32 deserializer: D,
33 ) -> Result<(), D::Error>
34 where
35 D: DeserializeContent<'de>;
36}
37
38pub trait TrackedOneOfVariant {
39 type Variant: Identifier;
40}
41
42pub trait TrackedOneOfDeserializer<'de>: TrackerFor + IdentifierFor + TrackedOneOfVariant + Sized
43where
44 Self::Tracker: TrackerWrapper,
45{
46 const DENY_UNKNOWN_FIELDS: bool = false;
47
48 fn deserialize<D>(
49 value: &mut Option<Self>,
50 identifier: Self::Variant,
51 tracker: &mut Option<<Self::Tracker as TrackerWrapper>::Tracker>,
52 deserializer: D,
53 ) -> Result<(), D::Error>
54 where
55 D: DeserializeContent<'de>;
56
57 fn tracker_to_identifier(tracker: &<Self::Tracker as TrackerWrapper>::Tracker) -> Self::Variant;
58 fn value_to_identifier(value: &Self) -> Self::Variant;
59}
60
61impl<'de, T> serde::de::Visitor<'de> for DeserializeHelper<'_, TaggedOneOfTracker<T>>
62where
63 T: Tracker,
64 T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
65 T::Target: IdentifierFor,
66 <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
67{
68 type Value = ();
69
70 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
71 <T::Target as Expected>::expecting(formatter)
72 }
73
74 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
75 where
76 A: serde::de::MapAccess<'de>,
77 {
78 while let Some(key) = map
79 .next_key_seed(IdentifierDeserializer::<<T::Target as IdentifierFor>::Identifier>::new())
80 .inspect_err(|_| {
81 set_irrecoverable();
82 })?
83 {
84 let _token = SerdePathToken::push_field(match &key {
85 IdentifiedValue::Found(tag) => tag.name(),
86 IdentifiedValue::Unknown(v) => v.as_ref(),
87 });
88
89 let mut deserialized = false;
90
91 match &key {
92 IdentifiedValue::Found(tag) => {
93 TrackerDeserializeIdentifier::deserialize(
94 self.tracker,
95 self.value,
96 *tag,
97 MapAccessValueDeserializer {
98 map: &mut map,
99 deserialized: &mut deserialized,
100 },
101 )?;
102 }
103 IdentifiedValue::Unknown(_) => {
104 report_tracked_error(TrackedError::unknown_field(T::Target::DENY_UNKNOWN_FIELDS))?;
105 }
106 }
107
108 if !deserialized {
109 map.next_value::<serde::de::IgnoredAny>().inspect_err(|_| {
110 set_irrecoverable();
111 })?;
112 }
113 }
114
115 Ok(())
116 }
117}
118
119impl<'de, T> TrackerDeserializer<'de> for OneOfTracker<T>
120where
121 T: Tracker,
122 T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
123{
124 fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
125 where
126 D: DeserializeContent<'de>,
127 {
128 deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
129 }
130}
131
/// Zero-sized marker whose [`TrackerFor`] impl selects [`OneOfTracker`] wrapping
/// `T`'s own tracker.
pub struct TrackerForOneOf<T>(std::marker::PhantomData<T>);
133
134impl<T: TrackerFor> TrackerFor for TrackerForOneOf<T> {
135 type Tracker = OneOfTracker<T::Tracker>;
136}
137
/// Bit in `TaggedOneOfTracker::state`: a tag value has been rejected.
const TAGGED_ONE_OF_TRACKER_STATE_TAG_INVALID: u8 = 1 << 0;
/// Bit in `TaggedOneOfTracker::state`: a content key has been seen.
const TAGGED_ONE_OF_TRACKER_STATE_HAS_CONTENT: u8 = 1 << 1;
140
141pub struct TaggedOneOfTracker<T>
142where
143 T: Tracker,
144 T::Target: TrackedOneOfVariant,
145{
146 tracker: Option<T>,
147 state: u8,
148 tag_buffer: Option<<T::Target as TrackedOneOfVariant>::Variant>,
149 content_buffer: Vec<serde_json::Value>,
150}
151
152impl<T: Tracker> TrackerWrapper for TaggedOneOfTracker<T>
153where
154 T::Target: TrackedOneOfVariant,
155{
156 type Tracker = T;
157}
158
/// Routes one key of a tagged one-of map to the matching handler.
///
/// The `TAG` key names the variant and the `CONTENT` key carries its payload.
/// Any other identifier is reported through `TrackedError::unknown_field`. The two
/// keys may arrive in either order. Content seen before the tag is buffered as
/// `serde_json::Value`s and replayed once a valid tag arrives.
impl<'de, T> TrackerDeserializeIdentifier<'de> for TaggedOneOfTracker<T>
where
    T: Tracker,
    T::Target: TrackedOneOfVariant + IdentifierFor,
    <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
    T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
{
    /// Handles the map entry named by `identifier`, reading its value from `deserializer`.
    ///
    /// # Errors
    /// Propagates failures to read the tag, failures to buffer raw content, and
    /// any error that `report_de_error` / `report_tracked_error` hand back.
    fn deserialize<D>(
        &mut self,
        value: &mut Self::Target,
        identifier: <Self::Target as IdentifierFor>::Identifier,
        deserializer: D,
    ) -> Result<(), D::Error>
    where
        D: DeserializeContent<'de>,
    {
        if identifier == <T::Target as IdentifierFor>::Identifier::TAG {
            // The tag's value is itself an identifier naming the variant.
            let tag = deserializer.deserialize_seed(IdentifierDeserializer::new())?;
            match (tag, self.tag_buffer) {
                // Recognised variant, and no earlier tag has been rejected.
                (IdentifiedValue::Found(tag), _) if !self.tag_invalid() => {
                    if let Some(existing_tag) = self.tag_buffer {
                        // A repeated tag must name the same variant. On mismatch the
                        // error is reported and the first tag stays in `tag_buffer`.
                        if existing_tag != tag {
                            let error = <D::Error as serde::de::Error>::invalid_value(
                                Unexpected::Str(tag.name()),
                                &existing_tag.name(),
                            );
                            report_de_error(error)?;
                        }
                    } else {
                        self.tag_buffer = Some(tag);
                    }

                    // Replay content that arrived before the tag. Errors are
                    // attributed to the content field rather than the tag field.
                    // In the code visible here the buffer is only filled while
                    // `tag_buffer` is `None`, so it is empty on repeated tags.
                    let _token = SerdePathToken::replace_field(<T::Target as IdentifierFor>::Identifier::CONTENT.name());
                    for content in self.content_buffer.drain(..) {
                        // Buffered content is re-read through a JSON value deserializer,
                        // so its error type is converted into `D::Error`.
                        let result: Result<(), D::Error> = T::Target::deserialize(
                            value,
                            tag,
                            &mut self.tracker,
                            SerdeDeserializer {
                                deserializer: serde::de::IntoDeserializer::into_deserializer(content),
                            },
                        )
                        .map_err(serde::de::Error::custom);

                        if let Err(e) = result {
                            report_de_error(e)?;
                        }
                    }
                }
                // Unrecognised tag and no prior tag: mark the tag invalid (later
                // content is then ignored) and report an unknown variant.
                (IdentifiedValue::Unknown(v), None) => {
                    self.set_tag_invalid();
                    let error = <D::Error as serde::de::Error>::unknown_variant(
                        v.as_ref(),
                        <T::Target as TrackedOneOfVariant>::Variant::OPTIONS,
                    );
                    report_de_error(error)?;
                }
                // Unrecognised tag after a valid one: report it against the first tag.
                (IdentifiedValue::Unknown(v), Some(tag)) => {
                    self.set_tag_invalid();
                    let error = <D::Error as serde::de::Error>::invalid_value(Unexpected::Str(v.as_ref()), &tag.name());
                    report_de_error(error)?;
                }
                // Recognised tag while the tag is already invalid: the problem was
                // reported when it became invalid, so nothing more is done.
                _ => {}
            }
        } else if identifier == <T::Target as IdentifierFor>::Identifier::CONTENT {
            self.set_has_content();
            // Once the tag is invalid, the content value is left unread here.
            if !self.tag_invalid() {
                if let Some(tag) = self.tag_buffer {
                    // Tag already known: deserialize the variant payload directly.
                    let result: Result<(), D::Error> = T::Target::deserialize(value, tag, &mut self.tracker, deserializer);
                    if let Err(e) = result {
                        report_de_error(e)?;
                    }
                } else {
                    // Tag not seen yet: buffer the raw content for replay. A failure
                    // to read it is marked irrecoverable.
                    self.content_buffer
                        .push(deserializer.deserialize::<serde_json::Value>().inspect_err(|_| {
                            set_irrecoverable();
                        })?);
                }
            }
        } else {
            report_tracked_error(TrackedError::unknown_field(T::Target::DENY_UNKNOWN_FIELDS))?;
        }

        Ok(())
    }
}
245
246impl<'de, T> TrackerDeserializer<'de> for TaggedOneOfTracker<T>
247where
248 T: Tracker,
249 T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
250 <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
251{
252 fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
253 where
254 D: DeserializeContent<'de>,
255 {
256 deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
257 }
258}
259
260impl<T> std::ops::Deref for TaggedOneOfTracker<T>
261where
262 T: Tracker,
263 T::Target: TrackedOneOfVariant,
264{
265 type Target = Option<T>;
266
267 fn deref(&self) -> &Self::Target {
268 &self.tracker
269 }
270}
271
272impl<T> std::ops::DerefMut for TaggedOneOfTracker<T>
273where
274 T: Tracker,
275 T::Target: TrackedOneOfVariant,
276{
277 fn deref_mut(&mut self) -> &mut Self::Target {
278 &mut self.tracker
279 }
280}
281
282impl<T> std::fmt::Debug for TaggedOneOfTracker<T>
283where
284 T: Tracker + std::fmt::Debug,
285 T::Target: TrackedOneOfVariant,
286{
287 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
288 f.debug_struct("TaggedOneOfTracker")
289 .field("tracker", &self.tracker)
290 .field("state", &self.state)
291 .field("tag_buffer", &self.tag_buffer.map(|t| t.name()))
292 .field("value_buffer", &self.content_buffer)
293 .finish()
294 }
295}
296
297impl<T> Default for TaggedOneOfTracker<T>
298where
299 T: Tracker,
300 T::Target: TrackedOneOfVariant,
301{
302 fn default() -> Self {
303 Self {
304 tracker: None,
305 state: 0,
306 tag_buffer: None,
307 content_buffer: Vec::new(),
308 }
309 }
310}
311
312impl<T> TaggedOneOfTracker<T>
313where
314 T: Tracker,
315 T::Target: TrackedOneOfVariant,
316{
317 pub fn tag_invalid(&self) -> bool {
318 self.state & TAGGED_ONE_OF_TRACKER_STATE_TAG_INVALID != 0
319 }
320
321 pub fn set_tag_invalid(&mut self) {
322 self.state |= TAGGED_ONE_OF_TRACKER_STATE_TAG_INVALID;
323 }
324
325 pub fn has_content(&self) -> bool {
326 self.state & TAGGED_ONE_OF_TRACKER_STATE_HAS_CONTENT != 0
327 }
328
329 pub fn set_has_content(&mut self) {
330 self.state |= TAGGED_ONE_OF_TRACKER_STATE_HAS_CONTENT;
331 }
332}
333
334impl<T> Tracker for TaggedOneOfTracker<T>
335where
336 T: Tracker,
337 T::Target: TrackedOneOfVariant,
338{
339 type Target = Option<T::Target>;
340
341 fn allow_duplicates(&self) -> bool {
342 self.tracker.as_ref().is_none_or(|t| t.allow_duplicates())
343 }
344}
345
346impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, TaggedOneOfTracker<T>>
347where
348 T: Tracker,
349 T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
350 <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
351{
352 type Value = ();
353
354 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
355 where
356 D: serde::Deserializer<'de>,
357 {
358 deserializer.deserialize_struct(T::Target::NAME, <T::Target as IdentifierFor>::Identifier::OPTIONS, self)
359 }
360}
361
/// Tracker for a one-of that is read via `serde::Deserializer::deserialize_enum`.
///
/// Wraps the inner tracker handed to `TrackedOneOfDeserializer::deserialize`. It
/// starts as `None` (see its `Default` impl).
#[derive(Debug)]
pub struct OneOfTracker<T>(pub Option<T>);
364
365impl<T: Tracker> TrackerWrapper for OneOfTracker<T> {
366 type Tracker = T;
367}
368
369impl<T> std::ops::Deref for OneOfTracker<T> {
370 type Target = Option<T>;
371
372 fn deref(&self) -> &Self::Target {
373 &self.0
374 }
375}
376
377impl<T> std::ops::DerefMut for OneOfTracker<T> {
378 fn deref_mut(&mut self) -> &mut Self::Target {
379 &mut self.0
380 }
381}
382
383impl<T> Default for OneOfTracker<T> {
384 fn default() -> Self {
385 Self(None)
386 }
387}
388
389impl<T: Tracker> Tracker for OneOfTracker<T> {
390 type Target = Option<T::Target>;
391
392 fn allow_duplicates(&self) -> bool {
393 self.0.as_ref().is_none_or(|value| value.allow_duplicates())
394 }
395}
396
397impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, OneOfTracker<T>>
398where
399 T: Tracker,
400 T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
401{
402 type Value = ();
403
404 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
405 where
406 D: serde::Deserializer<'de>,
407 {
408 deserializer.deserialize_enum(T::Target::NAME, <T::Target as IdentifierFor>::Identifier::OPTIONS, self)
409 }
410}
411
412impl<'de, T> serde::de::Visitor<'de> for DeserializeHelper<'_, OneOfTracker<T>>
413where
414 T: Tracker,
415 T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
416{
417 type Value = ();
418
419 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
420 write!(formatter, "one of")
421 }
422
423 fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
424 where
425 A: serde::de::EnumAccess<'de>,
426 {
427 let (variant, variant_access) =
428 data.variant_seed(IdentifierDeserializer::<<T::Target as IdentifierFor>::Identifier>::new())?;
429 match variant {
430 IdentifiedValue::Found(variant) => {
431 let _token = SerdePathToken::push_field(variant.name());
432 TrackerDeserializeIdentifier::deserialize(
433 self.tracker,
434 self.value,
435 variant,
436 VariantAccessDeserializer { de: variant_access },
437 )
438 }
439 IdentifiedValue::Unknown(variant) => {
440 let error = <A::Error as serde::de::Error>::unknown_variant(
441 variant.as_ref(),
442 <T::Target as IdentifierFor>::Identifier::OPTIONS,
443 );
444 report_de_error(error)?;
445 variant_access.newtype_variant::<serde::de::IgnoredAny>().inspect_err(|_| {
446 set_irrecoverable();
447 })?;
448 Ok(())
449 }
450 }
451 }
452}
453
454impl<'de, T> TrackerDeserializeIdentifier<'de> for OneOfTracker<T>
455where
456 T: Tracker,
457 T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
458{
459 fn deserialize<D>(
460 &mut self,
461 value: &mut Self::Target,
462 identifier: <Self::Target as IdentifierFor>::Identifier,
463 deserializer: D,
464 ) -> Result<(), D::Error>
465 where
466 D: DeserializeContent<'de>,
467 {
468 T::Target::deserialize(value, identifier, self, deserializer)
469 }
470}
471
/// Adapter that lets an enum variant's payload be consumed as `DeserializeContent`
/// by reading it as a newtype variant.
struct VariantAccessDeserializer<D> {
    /// The variant access to read from (e.g. the one returned by `EnumAccess::variant_seed`).
    de: D,
}
475
476impl<'de, D> DeserializeContent<'de> for VariantAccessDeserializer<D>
477where
478 D: serde::de::VariantAccess<'de>,
479{
480 type Error = D::Error;
481
482 fn deserialize_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
483 where
484 T: serde::de::DeserializeSeed<'de>,
485 {
486 self.de.newtype_variant_seed(seed)
487 }
488}