github.com/agnivade/pgm@v0.0.0-20210528073050-e2df0d9cb72d/pgm_index_test.go (about)

     1  package pgm
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"sort"
     7  	"testing"
     8  	"time"
     9  )
    10  
    11  func TestPGMIndex(t *testing.T) {
    12  	testCases := []struct {
    13  		data      []float64
    14  		segments  int // segments in the last level
    15  		epsilon   int
    16  		searchMem []struct {
    17  			k       float64
    18  			present bool
    19  		}
    20  		searchPred []float64
    21  	}{
    22  		{
    23  			data:     []float64{2, 12, 15, 18, 23, 24, 29, 31, 34, 36, 38, 48, 55, 59, 60, 71, 73, 74, 76, 88, 95, 102, 115, 122, 123, 124, 158, 159, 161, 164, 165, 187, 189, 190},
    24  			segments: 5,
    25  			epsilon:  1,
    26  			searchMem: []struct {
    27  				k       float64
    28  				present bool
    29  			}{
    30  				{k: 31, present: true},
    31  				{k: 32, present: false},
    32  				{k: 80, present: false},
    33  				{k: 95, present: true},
    34  				{k: 190, present: true},
    35  				{k: 200, present: false},
    36  			},
    37  			searchPred: []float64{1, 16, 15, 40, 48, 100, 200},
    38  		},
    39  		{
    40  			data:     []float64{21, 24, 46, 50, 52, 108, 109, 141, 147, 152, 178, 185, 275, 282, 310, 324, 332, 373, 380, 415, 433, 442, 452, 471, 476, 496},
    41  			segments: 3,
    42  			epsilon:  1,
    43  		},
    44  		{
    45  			data:     []float64{1, 2, 13, 36, 37, 57, 69, 107, 140, 176, 215, 229, 246, 260, 288, 324, 337, 341, 381, 390, 409, 411, 416, 442, 444, 453, 476, 497},
    46  			segments: 3,
    47  			epsilon:  1,
    48  		},
    49  		{
    50  			data:     []float64{11, 28, 119, 131, 167, 345, 348, 362, 369, 439},
    51  			segments: 2,
    52  			epsilon:  1,
    53  		},
    54  	}
    55  
    56  	for _, tc := range testCases {
    57  		sort.Float64s(tc.data)
    58  
    59  		ind := NewIndex(tc.data, tc.epsilon)
    60  		// for _, level := range ind.levels {
    61  		// 	fmt.Println(level)
    62  		// }
    63  
    64  		if len(ind.levels[0]) != tc.segments {
    65  			t.Errorf("incorrect number of segments. Got: %d, Want: %d", len(ind.levels[0]), tc.segments)
    66  		}
    67  		verifyIndex(t, ind, tc.data, tc.epsilon)
    68  
    69  		if tc.searchMem != nil {
    70  			for _, mem := range tc.searchMem {
    71  				pos, err := ind.Search(mem.k)
    72  				if err != nil {
    73  					t.Errorf("error received: %v", err)
    74  				}
    75  
    76  				found := false
    77  				for _, d := range tc.data[pos.Lo : pos.Hi+1] {
    78  					if d == mem.k {
    79  						found = true
    80  						break
    81  					}
    82  				}
    83  
    84  				if found != mem.present {
    85  					t.Errorf("incorrect membership result for %f. Got %t, Want: %t", mem.k, found, mem.present)
    86  				}
    87  			}
    88  		}
    89  
    90  		if tc.searchPred != nil {
    91  			for _, pred := range tc.searchPred {
    92  				pos, err := ind.Search(pred)
    93  				if err != nil {
    94  					t.Errorf("error received: %v", err)
    95  				}
    96  				t.Log(pred, tc.data[pos.Lo], tc.data[pos.Hi])
    97  
    98  				if pos.Lo >= pos.Hi {
    99  					t.Errorf("lo %d is greater than hi %d", pos.Lo, pos.Hi)
   100  				}
   101  				// pos.Lo <= k <= pos.Hi
   102  				if tc.data[pos.Lo] > pred && pos.Lo != 0 {
   103  					t.Errorf("lo %f is greater than k %f", tc.data[pos.Lo], pred)
   104  				}
   105  				if tc.data[pos.Hi] < pred && pos.Hi != len(tc.data)-1 {
   106  					t.Errorf("hi %f is lesser than k %f", tc.data[pos.Hi], pred)
   107  				}
   108  			}
   109  		}
   110  	}
   111  }
   112  
   113  func verifyIndex(t *testing.T, ind *Index, input []float64, epsilon int) {
   114  	verifySegment := func(t *testing.T, set []float64, start, end int, s Segment) {
   115  		t.Helper()
   116  		// Iterate all points in the segment and verify they are within e.
   117  		for i := start; i < end; i++ {
   118  			// (mx+c) - i = err
   119  			err := (s.slope*set[i] + s.intercept) - float64(i)
   120  			if err > float64(2*epsilon) {
   121  				t.Errorf("error threshold exceeded, x: %d, y:%f", i, set[i])
   122  			}
   123  		}
   124  	}
   125  
   126  	// Verify each level
   127  	for i := 0; i < len(ind.levels); i++ {
   128  		level := ind.levels[i]
   129  		var set []float64
   130  		if i == 0 {
   131  			set = input
   132  		} else {
   133  			set = set[:0] // reset
   134  			for _, seg := range ind.levels[i-1] {
   135  				set = append(set, seg.key)
   136  			}
   137  		}
   138  		// Find set of keys for each segment
   139  		current := 0
   140  		for j := 0; j < len(level); j++ {
   141  			s := level[j]
   142  			// Check if last segment or not
   143  			if j+1 == len(level) {
   144  				verifySegment(t, set, current, len(set), s)
   145  			} else {
   146  				nextKey := level[j+1].key
   147  				start := current
   148  				for set[current] != nextKey {
   149  					current++
   150  				}
   151  				verifySegment(t, set, start, current, s)
   152  			}
   153  		}
   154  	}
   155  }
   156  
   157  func TestGendata(t *testing.T) {
   158  	t.Skip("only to generate corpus")
   159  	var input []int
   160  	rand.Seed(time.Now().UnixNano())
   161  	for i := 0; i < 20; i++ {
   162  		input = append(input, rand.Intn(50))
   163  	}
   164  	sort.Ints(input)
   165  	input = removeDups(input)
   166  	fmt.Println(input)
   167  }
   168  
   169  func removeDups(in []int) []int {
   170  	j := 0
   171  	for i := 1; i < len(in); i++ {
   172  		if in[j] == in[i] {
   173  			continue
   174  		}
   175  		j++
   176  		in[j] = in[i]
   177  	}
   178  	return in[:j+1]
   179  }