github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/coder/arith/encoding.go (about) 1 // This package implements simple arithmetic coders 2 // based on code by https://github.com/rygorous/mini_arith 3 package arith 4 5 type Encoder struct { 6 lo, hi uint32 7 data []byte 8 } 9 10 func NewEncoder() *Encoder { return &Encoder{lo: 0, hi: ^uint32(0)} } 11 12 func (p P) midpoint32(lo, hi uint32) uint32 { 13 return lo + uint32(uint64(hi-lo)*uint64(p)>>probBits) 14 } 15 16 func (enc *Encoder) Encode(bit uint, prob P) { 17 x := prob.midpoint32(enc.lo, enc.hi) 18 19 if bit == 1 { 20 enc.hi = x 21 } else { 22 enc.lo = x + 1 23 } 24 25 for enc.lo^enc.hi < 1<<24 { 26 enc.data = append(enc.data, byte(enc.lo>>24)) 27 enc.lo <<= 8 28 enc.hi = enc.hi<<8 | 0xFF 29 } 30 } 31 32 func (enc *Encoder) Close() { 33 for i := 0; i < 4; i++ { 34 enc.data = append(enc.data, byte(enc.lo>>24)) 35 enc.lo <<= 8 36 } 37 } 38 39 func (enc *Encoder) Bytes() []byte { return enc.data } 40 41 type Decoder struct { 42 lo, hi uint32 43 code uint32 44 data []byte 45 read int 46 } 47 48 func NewDecoder(data []byte) *Decoder { 49 dec := &Decoder{lo: 0, hi: ^uint32(0)} 50 dec.data = data 51 52 for i := 0; i < 4; i++ { 53 dec.code = dec.code<<8 | uint32(dec.data[dec.read]) 54 dec.read++ 55 } 56 57 return dec 58 } 59 60 func (dec *Decoder) Decode(prob P) (bit uint) { 61 x := prob.midpoint32(dec.lo, dec.hi) 62 63 if dec.code <= x { 64 dec.hi = x 65 bit = 1 66 } else { 67 dec.lo = x + 1 68 bit = 0 69 } 70 71 for dec.lo^dec.hi < 1<<24 { 72 dec.code = dec.code<<8 | uint32(dec.data[dec.read]) 73 dec.read++ 74 dec.lo <<= 8 75 dec.hi = dec.hi<<8 | 0xff 76 } 77 78 return bit 79 }