github.com/agnivade/pgm@v0.0.0-20210528073050-e2df0d9cb72d/pgm_index_test.go (about) 1 package pgm 2 3 import ( 4 "fmt" 5 "math/rand" 6 "sort" 7 "testing" 8 "time" 9 ) 10 11 func TestPGMIndex(t *testing.T) { 12 testCases := []struct { 13 data []float64 14 segments int // segments in the last level 15 epsilon int 16 searchMem []struct { 17 k float64 18 present bool 19 } 20 searchPred []float64 21 }{ 22 { 23 data: []float64{2, 12, 15, 18, 23, 24, 29, 31, 34, 36, 38, 48, 55, 59, 60, 71, 73, 74, 76, 88, 95, 102, 115, 122, 123, 124, 158, 159, 161, 164, 165, 187, 189, 190}, 24 segments: 5, 25 epsilon: 1, 26 searchMem: []struct { 27 k float64 28 present bool 29 }{ 30 {k: 31, present: true}, 31 {k: 32, present: false}, 32 {k: 80, present: false}, 33 {k: 95, present: true}, 34 {k: 190, present: true}, 35 {k: 200, present: false}, 36 }, 37 searchPred: []float64{1, 16, 15, 40, 48, 100, 200}, 38 }, 39 { 40 data: []float64{21, 24, 46, 50, 52, 108, 109, 141, 147, 152, 178, 185, 275, 282, 310, 324, 332, 373, 380, 415, 433, 442, 452, 471, 476, 496}, 41 segments: 3, 42 epsilon: 1, 43 }, 44 { 45 data: []float64{1, 2, 13, 36, 37, 57, 69, 107, 140, 176, 215, 229, 246, 260, 288, 324, 337, 341, 381, 390, 409, 411, 416, 442, 444, 453, 476, 497}, 46 segments: 3, 47 epsilon: 1, 48 }, 49 { 50 data: []float64{11, 28, 119, 131, 167, 345, 348, 362, 369, 439}, 51 segments: 2, 52 epsilon: 1, 53 }, 54 } 55 56 for _, tc := range testCases { 57 sort.Float64s(tc.data) 58 59 ind := NewIndex(tc.data, tc.epsilon) 60 // for _, level := range ind.levels { 61 // fmt.Println(level) 62 // } 63 64 if len(ind.levels[0]) != tc.segments { 65 t.Errorf("incorrect number of segments. Got: %d, Want: %d", len(ind.levels[0]), tc.segments) 66 } 67 verifyIndex(t, ind, tc.data, tc.epsilon) 68 69 if tc.searchMem != nil { 70 for _, mem := range tc.searchMem { 71 pos, err := ind.Search(mem.k) 72 if err != nil { 73 t.Errorf("error received: %v", err) 74 } 75 76 found := false 77 for _, d := range tc.data[pos.Lo : pos.Hi+1] { 78 if d == mem.k { 79 found = true 80 break 81 } 82 } 83 84 if found != mem.present { 85 t.Errorf("incorrect membership result for %f. Got %t, Want: %t", mem.k, found, mem.present) 86 } 87 } 88 } 89 90 if tc.searchPred != nil { 91 for _, pred := range tc.searchPred { 92 pos, err := ind.Search(pred) 93 if err != nil { 94 t.Errorf("error received: %v", err) 95 } 96 t.Log(pred, tc.data[pos.Lo], tc.data[pos.Hi]) 97 98 if pos.Lo >= pos.Hi { 99 t.Errorf("lo %d is greater than hi %d", pos.Lo, pos.Hi) 100 } 101 // pos.Lo <= k <= pos.Hi 102 if tc.data[pos.Lo] > pred && pos.Lo != 0 { 103 t.Errorf("lo %f is greater than k %f", tc.data[pos.Lo], pred) 104 } 105 if tc.data[pos.Hi] < pred && pos.Hi != len(tc.data)-1 { 106 t.Errorf("hi %f is lesser than k %f", tc.data[pos.Hi], pred) 107 } 108 } 109 } 110 } 111 } 112 113 func verifyIndex(t *testing.T, ind *Index, input []float64, epsilon int) { 114 verifySegment := func(t *testing.T, set []float64, start, end int, s Segment) { 115 t.Helper() 116 // Iterate all points in the segment and verify they are within e. 117 for i := start; i < end; i++ { 118 // (mx+c) - i = err 119 err := (s.slope*set[i] + s.intercept) - float64(i) 120 if err > float64(2*epsilon) { 121 t.Errorf("error threshold exceeded, x: %d, y:%f", i, set[i]) 122 } 123 } 124 } 125 126 // Verify each level 127 for i := 0; i < len(ind.levels); i++ { 128 level := ind.levels[i] 129 var set []float64 130 if i == 0 { 131 set = input 132 } else { 133 set = set[:0] // reset 134 for _, seg := range ind.levels[i-1] { 135 set = append(set, seg.key) 136 } 137 } 138 // Find set of keys for each segment 139 current := 0 140 for j := 0; j < len(level); j++ { 141 s := level[j] 142 // Check if last segment or not 143 if j+1 == len(level) { 144 verifySegment(t, set, current, len(set), s) 145 } else { 146 nextKey := level[j+1].key 147 start := current 148 for set[current] != nextKey { 149 current++ 150 } 151 verifySegment(t, set, start, current, s) 152 } 153 } 154 } 155 } 156 157 func TestGendata(t *testing.T) { 158 t.Skip("only to generate corpus") 159 var input []int 160 rand.Seed(time.Now().UnixNano()) 161 for i := 0; i < 20; i++ { 162 input = append(input, rand.Intn(50)) 163 } 164 sort.Ints(input) 165 input = removeDups(input) 166 fmt.Println(input) 167 } 168 169 func removeDups(in []int) []int { 170 j := 0 171 for i := 1; i < len(in); i++ { 172 if in[j] == in[i] { 173 continue 174 } 175 j++ 176 in[j] = in[i] 177 } 178 return in[:j+1] 179 }