github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/bit/bit.go (about)

// Package bit implements bit-level encoding wrappers for io.Reader and io.Writer.
// All operations are little-endian where applicable: bits are packed
// least-significant bit first.
//
// Note that the Write/Read operations do not return errors, for convenience;
// you must check bit.Reader.Error() or bit.Writer.Error() manually.
//
     7  package bit
     8  
     9  import "io"
    10  
    11  type Writer struct {
    12  	w     io.Writer
    13  	bits  uint64
    14  	nbits uint
    15  	err   error
    16  }
    17  
    18  func NewWriter(w io.Writer) *Writer { return &Writer{w, 0, 0, nil} }
    19  
    20  // flush writes as much full bytes as possible to the underlying writer
    21  func (w *Writer) flush() {
    22  	if w.err != nil {
    23  		w.nbits = 0
    24  		return
    25  	}
    26  
    27  	var buf [16]byte
    28  	n := 0
    29  	for w.nbits > 8 {
    30  		buf[n] = byte(w.bits)
    31  		w.bits >>= 8
    32  		w.nbits -= 8
    33  		n++
    34  	}
    35  
    36  	_, w.err = w.w.Write(buf[0:n])
    37  }
    38  
    39  // flushAll writes all the remaining half bytes
    40  func (w *Writer) flushAll() {
    41  	w.flush()
    42  	if w.err != nil {
    43  		w.nbits = 0
    44  		return
    45  	}
    46  
    47  	if w.nbits > 0 {
    48  		_, w.err = w.w.Write([]byte{byte(w.bits)})
    49  		w.bits = 0
    50  		w.nbits = 0
    51  	}
    52  }
    53  
    54  func (w *Writer) Error() error { return w.err }
    55  
    56  // Align aligns the writer to the next byte
    57  func (w *Writer) Align() { w.flushAll() }
    58  
    59  // WriteBits writes width lowest bits to the underlying writer
    60  func (w *Writer) WriteBits(x uint64, width uint) {
    61  	if width > 32 {
    62  		w.WriteBits(uint64(uint32(x)), 32)
    63  		x >>= 32
    64  		width -= 32
    65  	}
    66  
    67  	x &= 1<<width - 1
    68  	w.bits |= x << w.nbits
    69  	w.nbits += width
    70  	if w.nbits > 16 {
    71  		w.flush()
    72  	}
    73  }
    74  
    75  // WriteBit writes the lowest bit in x to the underlying writer
    76  func (w *Writer) WriteBit(x int) {
    77  	w.WriteBits(uint64(x&1), 1)
    78  }
    79  
    80  // WriteBool writes a bool the underlying writer depending on x
    81  func (w *Writer) WriteBool(x bool) {
    82  	if x {
    83  		w.WriteBits(1, 1)
    84  	}
    85  	w.WriteBits(0, 1)
    86  }
    87  
    88  // WriteByte writes a byte to the underlying writer
    89  func (w *Writer) WriteByte(v byte) {
    90  	w.WriteBits(uint64(v), 8)
    91  }
    92  
    93  func (w *Writer) Close() error {
    94  	w.Align()
    95  	return w.err
    96  }
    97  
    98  type Reader struct {
    99  	r     io.Reader
   100  	bits  uint64
   101  	nbits uint
   102  	err   error
   103  }
   104  
   105  func NewReader(r io.Reader) *Reader {
   106  	return &Reader{r, 0, 8, nil}
   107  }
   108  
   109  // read reads a single byte from the underlying reader
   110  func (r *Reader) read() {
   111  	if r.err != nil {
   112  		r.nbits = 8
   113  		return
   114  	}
   115  
   116  	var temp [1]byte
   117  	_, r.err = r.r.Read(temp[:])
   118  	r.bits = uint64(temp[0])
   119  }
   120  
   121  func (r *Reader) Error() error { return r.err }
   122  
   123  // Align aligns the reader to the next byte so that the next ReadBits will start
   124  // reading a new byte from the underlying reader
   125  func (r *Reader) Align() {
   126  	r.nbits = 8
   127  }
   128  
   129  // ReadBits reads width bits from the underlying reader
   130  func (r *Reader) ReadBits(width uint) uint64 {
   131  	if r.err != nil {
   132  		return 0
   133  	}
   134  
   135  	left := 8 - int(r.nbits)
   136  	if left > int(width) {
   137  		mask := uint64((1 << width) - 1)
   138  		x := r.bits >> r.nbits
   139  		r.nbits += width
   140  		return x & mask
   141  	}
   142  
   143  	n := 8 - r.nbits
   144  	x := r.bits >> r.nbits
   145  	for int(width)-int(n) > 0 {
   146  		r.read()
   147  		r.nbits -= 8
   148  		if r.err != nil {
   149  			return 0
   150  		}
   151  		x |= r.bits << n
   152  		n += 8
   153  	}
   154  	r.nbits += width
   155  	mask := uint64(1<<width - 1)
   156  	return x & mask
   157  }
   158  
   159  // ReadBit reads a single bit from the underlying reader
   160  func (r *Reader) ReadBit() int { return int(r.ReadBits(1)) }
   161  
   162  // ReadBool reads a single bool from the underlying reader
   163  func (r *Reader) ReadBool() bool { return r.ReadBits(1) == 1 }
   164  
   165  // ReadByte reads a single bit from the underlying reader
   166  func (r *Reader) ReadByte() byte { return byte(r.ReadBits(8)) }