github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/coder/arith/models_test.go (about) 1 package arith 2 3 import ( 4 "math/rand" 5 "testing" 6 ) 7 8 // Tests a model provided by the constructor 9 func ModelTest(t *testing.T, N int, model func() Model) { 10 encm, decm := model(), model() 11 mask := uint(1<<encm.NBits() - 1) 12 13 bits := make([]uint, N) 14 for i := range bits { 15 bits[i] = uint(rand.Int()) & mask 16 } 17 18 enc := NewEncoder() 19 for _, bit := range bits { 20 encm.Encode(enc, bit) 21 } 22 enc.Close() 23 24 dec := NewDecoder(enc.Bytes()) 25 for i, v := range bits { 26 x := decm.Decode(dec) & mask 27 if x != v { 28 t.Fatalf("fail %v: %v got %v exp %v", bits, i, x, v) 29 } 30 } 31 } 32 33 func TestShiftModel(t *testing.T) { 34 ModelTest(t, 3, func() Model { return &Shift{Prob(0.5), 3} }) 35 ModelTest(t, 21, func() Model { return &Shift{Prob(0.1), 2} }) 36 ModelTest(t, 312532, func() Model { return &Shift{Prob(0.9), 1} }) 37 } 38 39 func TestShift2Model(t *testing.T) { 40 ModelTest(t, 3, func() Model { return &Shift2{Prob(0.5), 3, Prob(0.1), 1} }) 41 ModelTest(t, 21, func() Model { return &Shift2{Prob(0.25), 2, Prob(0.25), 1} }) 42 ModelTest(t, 312532, func() Model { return &Shift2{Prob(0.9), 1, Prob(0.5), 1} }) 43 } 44 45 func TestTreeModel(t *testing.T) { 46 B := func() Model { return &Shift{Prob(0.43), 3} } 47 ModelTest(t, 3, func() Model { return NewTree(1, B) }) 48 ModelTest(t, 3, func() Model { return NewTree(2, B) }) 49 ModelTest(t, 21, func() Model { return NewTree(2, B) }) 50 ModelTest(t, 4124, func() Model { return NewTree(3, B) }) 51 ModelTest(t, 41244, func() Model { return NewTree(3, B) }) 52 53 ModelTest(t, 41244, func() Model { 54 m := NewEmptyTree(2) 55 m[0] = &Shift{Prob(0.43), 3} 56 m[1] = &Shift{Prob(0.9), 2} 57 m[2] = &Shift2{Prob(0.5), 3, Prob(0.1), 1} 58 return m 59 }) 60 } 61 62 func TestByteModel(t *testing.T) { 63 B := func() Model { return &Shift{Prob(0.5), 5} } 64 T := func() Model { return NewTree(8, B) } 65 66 encm, decm := T(), T() 67 68 data := make([]byte, 25192) 69 for i := range data { 70 data[i] = byte(rand.Int()) 71 } 72 73 enc := NewEncoder() 74 for _, b := range data { 75 encm.Encode(enc, uint(b)) 76 } 77 enc.Close() 78 79 dec := NewDecoder(enc.Bytes()) 80 for i, v := range data { 81 x := decm.Decode(dec) 82 if byte(x) != v { 83 t.Fatalf("fail %v: got %v exp %v", i, x, v) 84 } 85 } 86 } 87 88 func TestShiftZeros(t *testing.T) { 89 model := func() Model { return &Shift{MaxP / 1000, 7} } 90 91 encm, decm := model(), model() 92 const N = 900 93 94 enc := NewEncoder() 95 encm.Encode(enc, 1) 96 for i := 0; i < N; i += 1 { 97 encm.Encode(enc, 0) 98 } 99 enc.Close() 100 101 dec := NewDecoder(enc.Bytes()) 102 v := decm.Decode(dec) 103 if v != 1 { 104 t.Fatalf("0:got %v expected 1", v) 105 } 106 107 for i := 0; i < N; i += 1 { 108 v = decm.Decode(dec) 109 if v != 0 { 110 t.Fatalf("%d: got %v expected 0", i+1, v) 111 } 112 } 113 } 114 115 func TestShift2Zeros(t *testing.T) { 116 model := func() Model { return &Shift2{P0: 0x500, I0: 0x1, P1: 0x150, I1: 0x5} } 117 118 encm, decm := model(), model() 119 const N = 1e5 120 121 enc := NewEncoder() 122 encm.Encode(enc, 1) 123 for i := 0; i < N; i += 1 { 124 encm.Encode(enc, 0) 125 } 126 enc.Close() 127 128 dec := NewDecoder(enc.Bytes()) 129 v := decm.Decode(dec) 130 if v != 1 { 131 t.Fatalf("0:got %v expected 1", v) 132 } 133 134 for i := 0; i < N; i += 1 { 135 v = decm.Decode(dec) 136 if v != 0 { 137 t.Fatalf("%d: got %v expected 0", i+1, v) 138 } 139 } 140 }