github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/coder/arith/models.go (about)

     1  package arith
     2  
     3  import "github.com/egonelbre/exp/bit"
     4  
     5  type Model interface {
     6  	NBits() uint
     7  	Encode(enc *Encoder, value uint)
     8  	Decode(dec *Decoder) (value uint)
     9  }
    10  
    11  type Shift struct {
    12  	P P
    13  	I byte
    14  }
    15  
    16  func (m *Shift) NBits() uint { return 1 }
    17  
    18  func (m *Shift) adapt(bit uint) {
    19  	switch bit {
    20  	case 1:
    21  		m.P += (MaxP - m.P) >> m.I
    22  	case 0:
    23  		m.P -= m.P >> m.I
    24  	}
    25  }
    26  
    27  func (m *Shift) Encode(enc *Encoder, bit uint) {
    28  	enc.Encode(bit, m.P)
    29  	m.adapt(bit)
    30  }
    31  
    32  func (m *Shift) Decode(dec *Decoder) (bit uint) {
    33  	bit = dec.Decode(m.P)
    34  	m.adapt(bit)
    35  	return bit
    36  }
    37  
    38  type Shift2 struct {
    39  	P0 P
    40  	I0 byte
    41  
    42  	P1 P
    43  	I1 byte
    44  }
    45  
    46  func (m *Shift2) NBits() uint { return 1 }
    47  
    48  func (m *Shift2) adapt(bit uint) {
    49  	switch bit {
    50  	case 1:
    51  		m.P0 += (MaxP/2 - m.P0) >> m.I0
    52  		m.P1 += (MaxP/2 - m.P1) >> m.I1
    53  	case 0:
    54  		m.P0 -= m.P0 >> m.I0
    55  		m.P1 -= m.P1 >> m.I1
    56  	}
    57  }
    58  
    59  func (m *Shift2) Encode(enc *Encoder, bit uint) {
    60  	enc.Encode(bit, m.P0+m.P1)
    61  	m.adapt(bit)
    62  }
    63  
    64  func (m *Shift2) Decode(dec *Decoder) (bit uint) {
    65  	bit = dec.Decode(m.P0 + m.P1)
    66  	m.adapt(bit)
    67  	return bit
    68  }
    69  
    70  type Shift4 struct {
    71  	P [4]P
    72  	I [4]byte
    73  }
    74  
    75  func (m *Shift4) NBits() uint { return 1 }
    76  
    77  func (m *Shift4) adapt(bit uint) {
    78  	switch bit {
    79  	case 1:
    80  		m.P[0] += (MaxP/4 - m.P[0]) >> m.I[0]
    81  		m.P[1] += (MaxP/4 - m.P[1]) >> m.I[1]
    82  		m.P[2] += (MaxP/4 - m.P[2]) >> m.I[2]
    83  		m.P[3] += (MaxP/4 - m.P[3]) >> m.I[3]
    84  	case 0:
    85  		m.P[0] -= m.P[0] >> m.I[0]
    86  		m.P[1] -= m.P[1] >> m.I[1]
    87  		m.P[2] -= m.P[2] >> m.I[2]
    88  		m.P[3] -= m.P[3] >> m.I[3]
    89  	}
    90  }
    91  
    92  func (m *Shift4) Encode(enc *Encoder, bit uint) {
    93  	enc.Encode(bit, m.P[0]+m.P[1]+m.P[2]+m.P[3])
    94  	m.adapt(bit)
    95  }
    96  
    97  func (m *Shift4) Decode(dec *Decoder) (bit uint) {
    98  	bit = dec.Decode(m.P[0] + m.P[1] + m.P[2] + m.P[3])
    99  	m.adapt(bit)
   100  	return bit
   101  }
   102  
   103  type Tree []Model
   104  
   105  func (tree Tree) NBits() uint { return bit.ScanRight(uint64(tree.syms())) }
   106  
   107  func (tree Tree) syms() uint { return uint(len(tree) + 1) }
   108  func (tree Tree) msb() uint  { return tree.syms() / 2 }
   109  
   110  func NewTree(nbits uint, model func() Model) Tree {
   111  	syms := 1 << nbits
   112  	tree := make(Tree, syms-1)
   113  	for i := range tree {
   114  		tree[i] = model()
   115  	}
   116  	return tree
   117  }
   118  
   119  func NewEmptyTree(nbits uint) Tree {
   120  	return make(Tree, 1<<nbits-1)
   121  }
   122  
   123  func (tree Tree) Encode(enc *Encoder, value uint) {
   124  	if value > tree.syms() {
   125  		panic("")
   126  	}
   127  
   128  	syms, msb := tree.syms(), tree.msb()
   129  	ctx := uint(1)
   130  	for ctx < syms {
   131  		bit := uint(0)
   132  		if value&msb != 0 {
   133  			bit = 1
   134  		}
   135  
   136  		value += value
   137  		tree[ctx-1].Encode(enc, bit)
   138  		ctx += ctx + bit
   139  	}
   140  }
   141  
   142  func (tree Tree) Decode(dec *Decoder) (value uint) {
   143  	ctx := uint(1)
   144  	syms := tree.syms()
   145  	for ctx < syms {
   146  		ctx += ctx + tree[ctx-1].Decode(dec)
   147  	}
   148  	return ctx - syms
   149  }