github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/minhash/minhash_test.go (about)

     1  package minhash
     2  
     3  import (
     4  	"fmt"
     5  	"hash/crc64"
     6  	"math/rand"
     7  	"reflect"
     8  	"slices"
     9  	"sort"
    10  	"testing"
    11  
    12  	"github.com/fluhus/gostuff/gnum"
    13  )
    14  
    15  func TestCollection(t *testing.T) {
    16  	tests := []struct {
    17  		n     int
    18  		input []uint64
    19  		want  []uint64
    20  	}{
    21  		{
    22  			3,
    23  			[]uint64{1, 2, 2, 2, 2, 1, 1, 3, 3, 3, 1, 2, 3, 1, 3, 3, 2},
    24  			[]uint64{1, 2, 3},
    25  		},
    26  		{
    27  			3,
    28  			[]uint64{1, 2, 3, 4, 5, 6, 7, 8, 9},
    29  			[]uint64{1, 2, 3},
    30  		},
    31  		{
    32  			3,
    33  			[]uint64{9, 8, 7, 6, 5, 4, 3, 2, 1},
    34  			[]uint64{1, 2, 3},
    35  		},
    36  		{
    37  			5,
    38  			[]uint64{40, 19, 55, 10, 32, 1, 100, 5, 99, 16, 16},
    39  			[]uint64{1, 5, 10, 16, 19},
    40  		},
    41  	}
    42  	for _, test := range tests {
    43  		mh := New[uint64](test.n)
    44  		for _, k := range test.input {
    45  			mh.Push(k)
    46  		}
    47  		got := mh.View()
    48  		sort.Slice(got, func(i, j int) bool {
    49  			return got[i] < got[j]
    50  		})
    51  		if !reflect.DeepEqual(got, test.want) {
    52  			t.Errorf("New(%d).Push(%v)=%v, want %v",
    53  				test.n, test.input, got, test.want)
    54  		}
    55  	}
    56  }
    57  
    58  func TestJSON(t *testing.T) {
    59  	input := New[int](5)
    60  	input.Push(1)
    61  	input.Push(4)
    62  	input.Push(9)
    63  	input.Push(16)
    64  	input.Push(25)
    65  	input.Push(36)
    66  	jsn, err := input.MarshalJSON()
    67  	if err != nil {
    68  		t.Fatalf("MinHash(1,4,9,16,25,36).MarshalJSON() failed: %v", err)
    69  	}
    70  	got := New[int](2)
    71  	err = got.UnmarshalJSON(jsn)
    72  	if err != nil {
    73  		t.Fatalf("UnmarshalJSON(%q) failed: %v", jsn, err)
    74  	}
    75  	if !slices.Equal(got.View(), input.View()) {
    76  		t.Fatalf("UnmarshalJSON(%q)=%v, want %v", jsn, got, input)
    77  	}
    78  }
    79  
    80  func TestJaccard(t *testing.T) {
    81  	tests := []struct {
    82  		a, b []uint64
    83  		k    int
    84  		want float64
    85  	}{
    86  		{[]uint64{1, 2, 3}, []uint64{1, 2, 3}, 3, 1},
    87  		{[]uint64{1, 2, 3}, []uint64{2, 3, 4}, 3, 2.0 / 3.0},
    88  		{[]uint64{2, 3, 4}, []uint64{1, 2, 3}, 3, 2.0 / 3.0},
    89  		{[]uint64{1, 2, 3, 4, 5}, []uint64{1, 3, 5}, 5, 0.6},
    90  	}
    91  	for _, test := range tests {
    92  		a, b := New[uint64](test.k), New[uint64](test.k)
    93  		for _, i := range test.a {
    94  			a.Push(i)
    95  		}
    96  		for _, i := range test.b {
    97  			b.Push(i)
    98  		}
    99  		a.Sort()
   100  		b.Sort()
   101  		if got := a.Jaccard(b); gnum.Abs(got-test.want) > 0.00001 {
   102  			t.Errorf("Jaccard(%v,%v)=%f, want %f",
   103  				test.a, test.b, got, test.want)
   104  		}
   105  	}
   106  }
   107  
   108  func TestCollection_largeInput(t *testing.T) {
   109  	const k = 10000
   110  	tests := []struct {
   111  		from1, to1, from2, to2 int
   112  	}{
   113  		{1, 75000, 25000, 100000},
   114  		{1, 60000, 40000, 60000},
   115  		{1, 60000, 20000, 60000},
   116  		{1, 40000, 40001, 60000},
   117  	}
   118  	for _, test := range tests {
   119  		a, b := New[uint64](k), New[uint64](k)
   120  		h := crc64.New(crc64.MakeTable(crc64.ECMA))
   121  		for i := test.from1; i <= test.to1; i++ {
   122  			h.Reset()
   123  			fmt.Fprint(h, i)
   124  			a.Push(h.Sum64())
   125  		}
   126  		for i := test.from2; i <= test.to2; i++ {
   127  			h.Reset()
   128  			fmt.Fprint(h, i)
   129  			b.Push(h.Sum64())
   130  		}
   131  		a.Sort()
   132  		b.Sort()
   133  		want := float64(test.to1-test.from2+1) / float64(
   134  			test.to2-test.from1+1)
   135  		if got := a.Jaccard(b); gnum.Abs(got-want) > want/100 {
   136  			t.Errorf("Jaccard(...)=%f, want %f", got, want)
   137  		}
   138  	}
   139  }
   140  
   141  func FuzzCollection(f *testing.F) {
   142  	f.Add(1, 2, 3, 4, 5, 6)
   143  	f.Fuzz(func(t *testing.T, a int, b int, c int, d int, e int, f int) {
   144  		col := New[int](2)
   145  		col.Push(a)
   146  		col.Push(b)
   147  		col.Push(c)
   148  		col.Push(d)
   149  		col.Push(e)
   150  		col.Push(f)
   151  		v := col.View()
   152  		if len(v) != 2 {
   153  			t.Errorf("len()=%d, want %d", len(v), 2)
   154  		}
   155  		if v[0] < v[1] {
   156  			t.Errorf("v[0]<v[1]: %d<%d, want >=", v[0], v[1])
   157  		}
   158  	})
   159  }
   160  
   161  func BenchmarkPush(b *testing.B) {
   162  	nums := rand.Perm(b.N)
   163  	mh := New[int](b.N)
   164  	b.ResetTimer()
   165  	for i := 0; i < b.N; i++ {
   166  		mh.Push(nums[i])
   167  	}
   168  }