scuffle_av1/obu/
mod.rs

1use std::io;
2
3use scuffle_bytes_util::BitReader;
4use utils::read_leb128;
5
6pub mod seq;
7mod utils;
8
/// OBU Header
/// AV1-Spec-2 - 5.3.2
///
/// Parsed form of `obu_header()`. The `obu_forbidden_bit` and
/// `obu_reserved_1bit` fields are not stored; see [`ObuHeader::parse`].
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct ObuHeader {
    /// `obu_type`
    ///
    /// 4 bits
    pub obu_type: ObuType,
    /// `obu_size` if `obu_has_size_field` is 1, otherwise `None`.
    ///
    /// leb128()
    ///
    /// Per the AV1 spec this is the size in bytes of the OBU excluding the
    /// header bytes and the `obu_size` field itself.
    pub size: Option<u64>,
    /// `obu_extension_header()` if `obu_extension_flag` is 1, otherwise `None`.
    pub extension_header: Option<ObuExtensionHeader>,
}
24
/// OBU Extension Header
/// AV1-Spec-2 - 5.3.3
///
/// Present only when `obu_extension_flag` is 1. The trailing
/// `extension_header_reserved_3bits` are read but not stored.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct ObuExtensionHeader {
    /// `temporal_id`
    ///
    /// 3 bits
    pub temporal_id: u8,
    /// `spatial_id`
    ///
    /// 2 bits
    pub spatial_id: u8,
}
34
35impl ObuHeader {
36    /// Parses an OBU header from the given `cursor`.
37    pub fn parse(cursor: &mut impl io::Read) -> io::Result<Self> {
38        let mut bit_reader = BitReader::new(cursor);
39        let forbidden_bit = bit_reader.read_bit()?;
40        if forbidden_bit {
41            return Err(io::Error::new(io::ErrorKind::InvalidData, "obu_forbidden_bit is not 0"));
42        }
43
44        let obu_type = bit_reader.read_bits(4)?;
45        let extension_flag = bit_reader.read_bit()?;
46        let has_size_field = bit_reader.read_bit()?;
47
48        bit_reader.read_bit()?; // reserved_1bit
49
50        let extension_header = if extension_flag {
51            let temporal_id = bit_reader.read_bits(3)?;
52            let spatial_id = bit_reader.read_bits(2)?;
53            bit_reader.read_bits(3)?; // reserved_3bits
54            Some(ObuExtensionHeader {
55                temporal_id: temporal_id as u8,
56                spatial_id: spatial_id as u8,
57            })
58        } else {
59            None
60        };
61
62        let size = if has_size_field {
63            // obu_size
64            Some(read_leb128(&mut bit_reader)?)
65        } else {
66            None
67        };
68
69        if !bit_reader.is_aligned() {
70            return Err(io::Error::new(io::ErrorKind::InvalidData, "bit reader is not aligned"));
71        }
72
73        Ok(ObuHeader {
74            obu_type: ObuType::from(obu_type as u8),
75            size,
76            extension_header,
77        })
78    }
79}
80
/// OBU Type
/// AV1-Spec-2 - 6.2.2
///
/// The 4-bit `obu_type` field of an OBU header. Codes without a named variant
/// are preserved verbatim in [`ObuType::Reserved`].
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ObuType {
    /// `OBU_SEQUENCE_HEADER` (1)
    SequenceHeader,
    /// `OBU_TEMPORAL_DELIMITER` (2)
    TemporalDelimiter,
    /// `OBU_FRAME_HEADER` (3)
    FrameHeader,
    /// `OBU_TILE_GROUP` (4)
    TileGroup,
    /// `OBU_METADATA` (5)
    Metadata,
    /// `OBU_FRAME` (6)
    Frame,
    /// `OBU_REDUNDANT_FRAME_HEADER` (7)
    RedundantFrameHeader,
    /// `OBU_TILE_LIST` (8)
    TileList,
    /// `OBU_PADDING` (15)
    Padding,
    /// Reserved (any code without a named variant), stored as-is.
    ///
    /// Note: `Reserved(n)` where `n` is one of the named codes above does not
    /// round-trip. For example, `u8::from(ObuType::Reserved(1))` is `1`, but
    /// `ObuType::from(1)` is [`ObuType::SequenceHeader`].
    Reserved(u8),
}

impl From<u8> for ObuType {
    /// Maps a raw `obu_type` code to its variant. Unknown codes become `Reserved(code)`.
    fn from(value: u8) -> Self {
        match value {
            1 => Self::SequenceHeader,
            2 => Self::TemporalDelimiter,
            3 => Self::FrameHeader,
            4 => Self::TileGroup,
            5 => Self::Metadata,
            6 => Self::Frame,
            7 => Self::RedundantFrameHeader,
            8 => Self::TileList,
            15 => Self::Padding,
            unknown => Self::Reserved(unknown),
        }
    }
}

impl From<ObuType> for u8 {
    /// Returns the raw `obu_type` code for a variant. `Reserved` yields its stored code.
    fn from(value: ObuType) -> Self {
        match value {
            ObuType::Reserved(code) => code,
            ObuType::Padding => 15,
            ObuType::TileList => 8,
            ObuType::RedundantFrameHeader => 7,
            ObuType::Frame => 6,
            ObuType::Metadata => 5,
            ObuType::TileGroup => 4,
            ObuType::FrameHeader => 3,
            ObuType::TemporalDelimiter => 2,
            ObuType::SequenceHeader => 1,
        }
    }
}
140
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    // `Buf` supplies `remaining()` on `std::io::Cursor`, used to check how many
    // bytes the parser left unconsumed.
    use bytes::Buf;

    use super::*;

    /// Sequence-header OBU with a size field and no extension header.
    #[test]
    fn test_obu_header_parse() {
        // 0x0A = 0b0_0001_0_1_0: forbidden=0, obu_type=1 (sequence header),
        // extension_flag=0, has_size_field=1, reserved=0.
        // 0x0F: leb128 obu_size = 15, followed by 15 payload bytes.
        let mut cursor = std::io::Cursor::new(b"\n\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@");
        let header = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(header, @r"
        ObuHeader {
            obu_type: SequenceHeader,
            size: Some(
                15,
            ),
            extension_header: None,
        }
        ");

        // Only the two header bytes are consumed; the 15 payload bytes remain.
        assert_eq!(cursor.position(), 2);
        assert_eq!(cursor.remaining(), 15);
    }

    /// All-zero header byte: reserved type 0, no extension, no size field.
    #[test]
    fn test_obu_header_parse_no_size_field() {
        let mut cursor = std::io::Cursor::new(b"\x00");
        let header = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(header, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: None,
        }
        ");

        // Exactly one byte is consumed when neither optional part is present.
        assert_eq!(cursor.position(), 1);
        assert_eq!(cursor.remaining(), 0);
    }

    /// Header with an extension byte but no size field.
    #[test]
    fn test_obu_header_parse_extension_header() {
        // Byte 0: forbidden=0, obu_type=0000, extension_flag=1, has_size_field=0, reserved=0.
        // Byte 1: temporal_id=0b110 (6), spatial_id=0b10 (2), reserved_3bits=000.
        let mut cursor = std::io::Cursor::new([0b00000100, 0b11010000]);
        let header = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(header, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: Some(
                ObuExtensionHeader {
                    temporal_id: 6,
                    spatial_id: 2,
                },
            ),
        }
        ");

        assert_eq!(cursor.position(), 2);
        assert_eq!(cursor.remaining(), 0);
    }

    /// A leading 0xFF byte has `obu_forbidden_bit` set and must be rejected.
    #[test]
    fn test_obu_header_forbidden_bit_set() {
        let err = ObuHeader::parse(&mut std::io::Cursor::new(
            b"\xff\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@",
        ))
        .unwrap_err();
        insta::assert_debug_snapshot!(err, @r#"
        Custom {
            kind: InvalidData,
            error: "obu_forbidden_bit is not 0",
        }
        "#);
    }

    /// Every named `ObuType` round-trips through `u8`, as do a few reserved codes.
    #[test]
    fn test_obu_to_from_u8() {
        let case = [
            (ObuType::SequenceHeader, 1),
            (ObuType::TemporalDelimiter, 2),
            (ObuType::FrameHeader, 3),
            (ObuType::TileGroup, 4),
            (ObuType::Metadata, 5),
            (ObuType::Frame, 6),
            (ObuType::RedundantFrameHeader, 7),
            (ObuType::TileList, 8),
            (ObuType::Padding, 15),
            (ObuType::Reserved(0), 0),
            (ObuType::Reserved(100), 100),
        ];

        for (obu_type, value) in case {
            assert_eq!(u8::from(obu_type), value);
            assert_eq!(ObuType::from(value), obu_type);
        }
    }
}