1use std::io;
2
/// Reads a byte stream one bit (or a handful of bits) at a time, most-significant
/// bit of each byte first.
///
/// The byte currently being consumed is cached in `current_byte`, and `bit_pos`
/// is the index of the next bit to take from it (0 = MSB). While `bit_pos` is 0
/// the reader is byte-aligned, and the next bit read pulls a fresh byte from
/// `data`.
#[derive(Debug)]
#[must_use]
pub struct BitReader<T> {
    /// The wrapped byte source.
    data: T,
    /// Index (0..8) of the next bit to read from `current_byte`; 0 means aligned.
    bit_pos: u8,
    /// Most recently fetched source byte; only consulted while `bit_pos != 0`.
    current_byte: u8,
}

impl<T> BitReader<T> {
    /// Wraps `data` in a byte-aligned `BitReader`.
    ///
    /// Nothing is read from `data` until the first read call.
    pub const fn new(data: T) -> Self {
        Self {
            current_byte: 0,
            bit_pos: 0,
            data,
        }
    }
}
22
23impl<T: io::Read> BitReader<T> {
24 pub fn read_bit(&mut self) -> io::Result<bool> {
26 if self.is_aligned() {
27 self.update_byte()?;
28 }
29
30 let bit = (self.current_byte >> (7 - self.bit_pos)) & 1;
31
32 self.bit_pos = (self.bit_pos + 1) % 8;
33
34 Ok(bit == 1)
35 }
36
37 fn update_byte(&mut self) -> io::Result<()> {
38 let mut buf = [0];
39 self.data.read_exact(&mut buf)?;
40 self.current_byte = buf[0];
41 Ok(())
42 }
43
44 pub fn read_bits(&mut self, count: u8) -> io::Result<u64> {
46 let count = count.min(64);
47
48 let mut bits = 0;
49 for _ in 0..count {
50 let bit = self.read_bit()?;
51 bits <<= 1;
52 bits |= if bit { 1 } else { 0 };
53 }
54
55 Ok(bits)
56 }
57
58 #[inline(always)]
60 pub fn align(&mut self) -> io::Result<()> {
61 self.bit_pos = 0;
64 Ok(())
65 }
66}
67
68impl<T> BitReader<T> {
69 #[inline(always)]
71 #[must_use]
72 pub fn into_inner(self) -> T {
73 self.data
74 }
75
76 #[inline(always)]
78 #[must_use]
79 pub const fn get_ref(&self) -> &T {
80 &self.data
81 }
82
83 #[inline(always)]
85 #[must_use]
86 pub const fn bit_pos(&self) -> u8 {
87 self.bit_pos
88 }
89
90 #[inline(always)]
92 #[must_use]
93 pub const fn is_aligned(&self) -> bool {
94 self.bit_pos == 0
95 }
96}
97
98impl<T: io::Read> io::Read for BitReader<T> {
99 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
100 if self.is_aligned() {
103 return self.data.read(buf);
104 }
105
106 for byte in buf.iter_mut() {
117 *byte = 0;
118 for _ in 0..8 {
119 let bit = self.read_bit()?;
120 *byte <<= 1;
121 *byte |= bit as u8;
122 }
123 }
124
125 Ok(buf.len())
126 }
127}
128
129impl<B: AsRef<[u8]>> BitReader<std::io::Cursor<B>> {
130 pub const fn new_from_slice(data: B) -> Self {
132 Self::new(std::io::Cursor::new(data))
133 }
134}
135
136impl<W: io::Seek + io::Read> BitReader<W> {
137 pub fn bit_stream_position(&mut self) -> io::Result<u64> {
139 let pos = self.data.stream_position()?;
140 Ok(pos * 8 + if self.is_aligned() { 8 } else { self.bit_pos as u64 } - 8)
141 }
142
143 pub fn seek_bits(&mut self, count: i64) -> io::Result<u64> {
146 if count == 0 {
148 return self.bit_stream_position();
149 }
150
151 let count = self.bit_pos as i64 + count;
152
153 let bit_move = count % 8;
158 let mut byte_move = count / 8;
160
161 if !self.is_aligned() {
164 byte_move -= 1;
165 }
166
167 if bit_move < 0 {
170 byte_move -= 1;
171 }
172
173 let mut pos = self.data.seek(io::SeekFrom::Current(byte_move))? * 8;
174
175 self.bit_pos = ((8 + bit_move) % 8) as u8;
180
181 if !self.is_aligned() {
185 self.update_byte()?;
186 pos += self.bit_pos as u64;
187 }
188
189 Ok(pos)
190 }
191}
192
193impl<T: io::Seek + io::Read> io::Seek for BitReader<T> {
194 fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
195 match pos {
196 io::SeekFrom::Current(offset) if !self.is_aligned() => {
199 Ok(self.seek_bits(offset * 8)?.div_ceil(8))
201 }
202 _ => {
205 self.bit_pos = 0;
206 self.data.seek(pos)
207 }
208 }
209 }
210}
211
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use io::{Read, Seek};

    use super::*;

    /// Reading one bit at a time yields the input's bits MSB-first, then fails at EOF.
    #[test]
    fn test_bit_reader() {
        let binary = 0b10101010110011001111000101010101u32;

        // Big-endian bytes, so stream bit `i` is bit `31 - i` of `binary`.
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
        for i in 0..32 {
            assert_eq!(
                reader.read_bit().unwrap(),
                (binary & (1 << (31 - i))) != 0,
                "bit {i} is not correct"
            );
        }

        assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
    }

    /// `read_bits` with varying widths, including reads that straddle byte boundaries.
    #[test]
    fn test_bit_reader_read_bits() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
        // (bit count, expected value); the counts add up to exactly 32 bits.
        let cases = [
            (3, 0b101),
            (4, 0b0101),
            (3, 0b011),
            (3, 0b001),
            (3, 0b100),
            (3, 0b111),
            (5, 0b10001),
            (1, 0b0),
            (7, 0b1010101),
        ];

        for (i, (count, expected)) in cases.into_iter().enumerate() {
            assert_eq!(
                reader.read_bits(count).ok(),
                Some(expected),
                "reading {count} bits ({i}) are not correct"
            );
        }

        assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
    }

    /// After one bit is read, `align` discards the rest of that byte without reading
    /// further from the source, so each iteration advances exactly one byte.
    #[test]
    fn test_bit_reader_align() {
        let mut reader = BitReader::new_from_slice([
            0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000,
        ]);

        for i in 0..6 {
            let pos = reader.data.stream_position().unwrap();
            assert_eq!(pos, i, "stream pos");
            assert_eq!(reader.bit_pos(), 0, "bit pos");
            assert!(reader.read_bit().unwrap(), "bit {i} is not correct");
            reader.align().unwrap();
            // Reading the bit already fetched the byte; align must not fetch another.
            let pos = reader.data.stream_position().unwrap();
            assert_eq!(pos, i + 1, "stream pos");
            assert_eq!(reader.bit_pos(), 0, "bit pos");
        }

        assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
    }

    /// `io::Read` returns whole bytes from the current bit offset, whether or not
    /// the reader is aligned.
    #[test]
    fn test_bit_reader_io_read() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());

        let mut buf = [0; 1];
        // Aligned: the first byte comes back verbatim.
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [0b10101010]);

        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        // Unaligned by one bit: the low 7 bits of 0b11001100 followed by the top
        // bit of 0b11110001.
        let mut buf = [0; 1];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [0b10011001]);
    }

    /// `seek_bits` forwards and backwards. Return values are absolute bit positions.
    /// While unaligned, the underlying stream sits one byte past the partially
    /// read byte.
    #[test]
    fn test_bit_reader_seek() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());

        // Mid-byte target: byte 0 is fetched, so the stream advances to 1.
        assert_eq!(reader.seek_bits(5).unwrap(), 5);
        assert_eq!(reader.data.stream_position().unwrap(), 1);
        assert_eq!(reader.bit_pos(), 5);
        assert_eq!(reader.read_bits(1).unwrap(), 0b0);
        assert_eq!(reader.bit_pos(), 6);

        // A zero-length seek just reports the current bit position.
        assert_eq!(reader.seek_bits(0).unwrap(), 6);

        // Landing exactly on a byte boundary does not fetch the byte.
        assert_eq!(reader.seek_bits(10).unwrap(), 16);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 3);

        // Backwards by a whole byte keeps the bit offset (17 -> 9).
        assert_eq!(reader.seek_bits(-8).unwrap(), 9);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 2);
        assert_eq!(reader.data.stream_position().unwrap(), 2);

        // Backwards onto a byte boundary (10 -> 8) leaves the reader aligned.
        assert_eq!(reader.seek_bits(-2).unwrap(), 8);
        assert_eq!(reader.data.stream_position().unwrap(), 1);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
    }

    /// `io::Seek`: `Start`/`End` seeks reset to byte alignment. `Current` seeks on an
    /// unaligned reader keep the bit offset and report the underlying stream position.
    #[test]
    fn test_bit_reader_io_seek() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
        assert_eq!(reader.seek(io::SeekFrom::Start(1)).unwrap(), 1);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.data.stream_position().unwrap(), 1);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 2);

        // Unaligned relative seek: bit offset 1 is preserved.
        assert_eq!(reader.seek(io::SeekFrom::Current(1)).unwrap(), 3);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 3);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 2);
        assert_eq!(reader.data.stream_position().unwrap(), 3);

        assert_eq!(reader.seek(io::SeekFrom::Current(-1)).unwrap(), 2);
        assert_eq!(reader.bit_pos(), 2);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
        assert_eq!(reader.read_bits(1).unwrap(), 0b0);
        assert_eq!(reader.bit_pos(), 3);
        assert_eq!(reader.data.stream_position().unwrap(), 2);

        // Absolute seek from the end discards the pending bits.
        assert_eq!(reader.seek(io::SeekFrom::End(-1)).unwrap(), 3);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.data.stream_position().unwrap(), 3);
        assert_eq!(reader.read_bits(1).unwrap(), 0b0);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 4);
    }
}