github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/coder/arith/models_test.go (about)

     1  package arith
     2  
     3  import (
     4  	"math/rand"
     5  	"testing"
     6  )
     7  
     8  // Tests a model provided by the constructor
     9  func ModelTest(t *testing.T, N int, model func() Model) {
    10  	encm, decm := model(), model()
    11  	mask := uint(1<<encm.NBits() - 1)
    12  
    13  	bits := make([]uint, N)
    14  	for i := range bits {
    15  		bits[i] = uint(rand.Int()) & mask
    16  	}
    17  
    18  	enc := NewEncoder()
    19  	for _, bit := range bits {
    20  		encm.Encode(enc, bit)
    21  	}
    22  	enc.Close()
    23  
    24  	dec := NewDecoder(enc.Bytes())
    25  	for i, v := range bits {
    26  		x := decm.Decode(dec) & mask
    27  		if x != v {
    28  			t.Fatalf("fail %v: %v got %v exp %v", bits, i, x, v)
    29  		}
    30  	}
    31  }
    32  
    33  func TestShiftModel(t *testing.T) {
    34  	ModelTest(t, 3, func() Model { return &Shift{Prob(0.5), 3} })
    35  	ModelTest(t, 21, func() Model { return &Shift{Prob(0.1), 2} })
    36  	ModelTest(t, 312532, func() Model { return &Shift{Prob(0.9), 1} })
    37  }
    38  
    39  func TestShift2Model(t *testing.T) {
    40  	ModelTest(t, 3, func() Model { return &Shift2{Prob(0.5), 3, Prob(0.1), 1} })
    41  	ModelTest(t, 21, func() Model { return &Shift2{Prob(0.25), 2, Prob(0.25), 1} })
    42  	ModelTest(t, 312532, func() Model { return &Shift2{Prob(0.9), 1, Prob(0.5), 1} })
    43  }
    44  
    45  func TestTreeModel(t *testing.T) {
    46  	B := func() Model { return &Shift{Prob(0.43), 3} }
    47  	ModelTest(t, 3, func() Model { return NewTree(1, B) })
    48  	ModelTest(t, 3, func() Model { return NewTree(2, B) })
    49  	ModelTest(t, 21, func() Model { return NewTree(2, B) })
    50  	ModelTest(t, 4124, func() Model { return NewTree(3, B) })
    51  	ModelTest(t, 41244, func() Model { return NewTree(3, B) })
    52  
    53  	ModelTest(t, 41244, func() Model {
    54  		m := NewEmptyTree(2)
    55  		m[0] = &Shift{Prob(0.43), 3}
    56  		m[1] = &Shift{Prob(0.9), 2}
    57  		m[2] = &Shift2{Prob(0.5), 3, Prob(0.1), 1}
    58  		return m
    59  	})
    60  }
    61  
    62  func TestByteModel(t *testing.T) {
    63  	B := func() Model { return &Shift{Prob(0.5), 5} }
    64  	T := func() Model { return NewTree(8, B) }
    65  
    66  	encm, decm := T(), T()
    67  
    68  	data := make([]byte, 25192)
    69  	for i := range data {
    70  		data[i] = byte(rand.Int())
    71  	}
    72  
    73  	enc := NewEncoder()
    74  	for _, b := range data {
    75  		encm.Encode(enc, uint(b))
    76  	}
    77  	enc.Close()
    78  
    79  	dec := NewDecoder(enc.Bytes())
    80  	for i, v := range data {
    81  		x := decm.Decode(dec)
    82  		if byte(x) != v {
    83  			t.Fatalf("fail %v: got %v exp %v", i, x, v)
    84  		}
    85  	}
    86  }
    87  
    88  func TestShiftZeros(t *testing.T) {
    89  	model := func() Model { return &Shift{MaxP / 1000, 7} }
    90  
    91  	encm, decm := model(), model()
    92  	const N = 900
    93  
    94  	enc := NewEncoder()
    95  	encm.Encode(enc, 1)
    96  	for i := 0; i < N; i += 1 {
    97  		encm.Encode(enc, 0)
    98  	}
    99  	enc.Close()
   100  
   101  	dec := NewDecoder(enc.Bytes())
   102  	v := decm.Decode(dec)
   103  	if v != 1 {
   104  		t.Fatalf("0:got %v expected 1", v)
   105  	}
   106  
   107  	for i := 0; i < N; i += 1 {
   108  		v = decm.Decode(dec)
   109  		if v != 0 {
   110  			t.Fatalf("%d: got %v expected 0", i+1, v)
   111  		}
   112  	}
   113  }
   114  
   115  func TestShift2Zeros(t *testing.T) {
   116  	model := func() Model { return &Shift2{P0: 0x500, I0: 0x1, P1: 0x150, I1: 0x5} }
   117  
   118  	encm, decm := model(), model()
   119  	const N = 1e5
   120  
   121  	enc := NewEncoder()
   122  	encm.Encode(enc, 1)
   123  	for i := 0; i < N; i += 1 {
   124  		encm.Encode(enc, 0)
   125  	}
   126  	enc.Close()
   127  
   128  	dec := NewDecoder(enc.Bytes())
   129  	v := decm.Decode(dec)
   130  	if v != 1 {
   131  		t.Fatalf("0:got %v expected 1", v)
   132  	}
   133  
   134  	for i := 0; i < N; i += 1 {
   135  		v = decm.Decode(dec)
   136  		if v != 0 {
   137  			t.Fatalf("%d: got %v expected 0", i+1, v)
   138  		}
   139  	}
   140  }