github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/coder/arith/models.go (about) 1 package arith 2 3 import "github.com/egonelbre/exp/bit" 4 5 type Model interface { 6 NBits() uint 7 Encode(enc *Encoder, value uint) 8 Decode(dec *Decoder) (value uint) 9 } 10 11 type Shift struct { 12 P P 13 I byte 14 } 15 16 func (m *Shift) NBits() uint { return 1 } 17 18 func (m *Shift) adapt(bit uint) { 19 switch bit { 20 case 1: 21 m.P += (MaxP - m.P) >> m.I 22 case 0: 23 m.P -= m.P >> m.I 24 } 25 } 26 27 func (m *Shift) Encode(enc *Encoder, bit uint) { 28 enc.Encode(bit, m.P) 29 m.adapt(bit) 30 } 31 32 func (m *Shift) Decode(dec *Decoder) (bit uint) { 33 bit = dec.Decode(m.P) 34 m.adapt(bit) 35 return bit 36 } 37 38 type Shift2 struct { 39 P0 P 40 I0 byte 41 42 P1 P 43 I1 byte 44 } 45 46 func (m *Shift2) NBits() uint { return 1 } 47 48 func (m *Shift2) adapt(bit uint) { 49 switch bit { 50 case 1: 51 m.P0 += (MaxP/2 - m.P0) >> m.I0 52 m.P1 += (MaxP/2 - m.P1) >> m.I1 53 case 0: 54 m.P0 -= m.P0 >> m.I0 55 m.P1 -= m.P1 >> m.I1 56 } 57 } 58 59 func (m *Shift2) Encode(enc *Encoder, bit uint) { 60 enc.Encode(bit, m.P0+m.P1) 61 m.adapt(bit) 62 } 63 64 func (m *Shift2) Decode(dec *Decoder) (bit uint) { 65 bit = dec.Decode(m.P0 + m.P1) 66 m.adapt(bit) 67 return bit 68 } 69 70 type Shift4 struct { 71 P [4]P 72 I [4]byte 73 } 74 75 func (m *Shift4) NBits() uint { return 1 } 76 77 func (m *Shift4) adapt(bit uint) { 78 switch bit { 79 case 1: 80 m.P[0] += (MaxP/4 - m.P[0]) >> m.I[0] 81 m.P[1] += (MaxP/4 - m.P[1]) >> m.I[1] 82 m.P[2] += (MaxP/4 - m.P[2]) >> m.I[2] 83 m.P[3] += (MaxP/4 - m.P[3]) >> m.I[3] 84 case 0: 85 m.P[0] -= m.P[0] >> m.I[0] 86 m.P[1] -= m.P[1] >> m.I[1] 87 m.P[2] -= m.P[2] >> m.I[2] 88 m.P[3] -= m.P[3] >> m.I[3] 89 } 90 } 91 92 func (m *Shift4) Encode(enc *Encoder, bit uint) { 93 enc.Encode(bit, m.P[0]+m.P[1]+m.P[2]+m.P[3]) 94 m.adapt(bit) 95 } 96 97 func (m *Shift4) Decode(dec *Decoder) (bit uint) { 98 bit = dec.Decode(m.P[0] + m.P[1] + m.P[2] + m.P[3]) 99 m.adapt(bit) 100 return bit 101 } 102 103 type Tree []Model 104 105 func (tree Tree) NBits() uint { return bit.ScanRight(uint64(tree.syms())) } 106 107 func (tree Tree) syms() uint { return uint(len(tree) + 1) } 108 func (tree Tree) msb() uint { return tree.syms() / 2 } 109 110 func NewTree(nbits uint, model func() Model) Tree { 111 syms := 1 << nbits 112 tree := make(Tree, syms-1) 113 for i := range tree { 114 tree[i] = model() 115 } 116 return tree 117 } 118 119 func NewEmptyTree(nbits uint) Tree { 120 return make(Tree, 1<<nbits-1) 121 } 122 123 func (tree Tree) Encode(enc *Encoder, value uint) { 124 if value > tree.syms() { 125 panic("") 126 } 127 128 syms, msb := tree.syms(), tree.msb() 129 ctx := uint(1) 130 for ctx < syms { 131 bit := uint(0) 132 if value&msb != 0 { 133 bit = 1 134 } 135 136 value += value 137 tree[ctx-1].Encode(enc, bit) 138 ctx += ctx + bit 139 } 140 } 141 142 func (tree Tree) Decode(dec *Decoder) (value uint) { 143 ctx := uint(1) 144 syms := tree.syms() 145 for ctx < syms { 146 ctx += ctx + tree[ctx-1].Decode(dec) 147 } 148 return ctx - syms 149 }