github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/coder/arith/encoding.go (about)

// Package arith implements simple binary arithmetic coders,
// based on code from https://github.com/rygorous/mini_arith.
     3  package arith
     4  
     5  type Encoder struct {
     6  	lo, hi uint32
     7  	data   []byte
     8  }
     9  
    10  func NewEncoder() *Encoder { return &Encoder{lo: 0, hi: ^uint32(0)} }
    11  
    12  func (p P) midpoint32(lo, hi uint32) uint32 {
    13  	return lo + uint32(uint64(hi-lo)*uint64(p)>>probBits)
    14  }
    15  
    16  func (enc *Encoder) Encode(bit uint, prob P) {
    17  	x := prob.midpoint32(enc.lo, enc.hi)
    18  
    19  	if bit == 1 {
    20  		enc.hi = x
    21  	} else {
    22  		enc.lo = x + 1
    23  	}
    24  
    25  	for enc.lo^enc.hi < 1<<24 {
    26  		enc.data = append(enc.data, byte(enc.lo>>24))
    27  		enc.lo <<= 8
    28  		enc.hi = enc.hi<<8 | 0xFF
    29  	}
    30  }
    31  
    32  func (enc *Encoder) Close() {
    33  	for i := 0; i < 4; i++ {
    34  		enc.data = append(enc.data, byte(enc.lo>>24))
    35  		enc.lo <<= 8
    36  	}
    37  }
    38  
    39  func (enc *Encoder) Bytes() []byte { return enc.data }
    40  
    41  type Decoder struct {
    42  	lo, hi uint32
    43  	code   uint32
    44  	data   []byte
    45  	read   int
    46  }
    47  
    48  func NewDecoder(data []byte) *Decoder {
    49  	dec := &Decoder{lo: 0, hi: ^uint32(0)}
    50  	dec.data = data
    51  
    52  	for i := 0; i < 4; i++ {
    53  		dec.code = dec.code<<8 | uint32(dec.data[dec.read])
    54  		dec.read++
    55  	}
    56  
    57  	return dec
    58  }
    59  
    60  func (dec *Decoder) Decode(prob P) (bit uint) {
    61  	x := prob.midpoint32(dec.lo, dec.hi)
    62  
    63  	if dec.code <= x {
    64  		dec.hi = x
    65  		bit = 1
    66  	} else {
    67  		dec.lo = x + 1
    68  		bit = 0
    69  	}
    70  
    71  	for dec.lo^dec.hi < 1<<24 {
    72  		dec.code = dec.code<<8 | uint32(dec.data[dec.read])
    73  		dec.read++
    74  		dec.lo <<= 8
    75  		dec.hi = dec.hi<<8 | 0xff
    76  	}
    77  
    78  	return bit
    79  }