gonum.org/v1/gonum@v0.14.0/stat/card/card_test.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package card
     6  
     7  import (
     8  	"encoding"
     9  	"fmt"
    10  	"hash"
    11  	"hash/fnv"
    12  	"io"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  
    18  	"golang.org/x/exp/rand"
    19  
    20  	"gonum.org/v1/gonum/floats/scalar"
    21  )
    22  
    23  // exact is an exact cardinality accumulator.
    24  type exact map[string]struct{}
    25  
    26  func (e exact) Write(b []byte) (int, error) {
    27  	if _, exists := e[string(b)]; exists {
    28  		return len(b), nil
    29  	}
    30  	e[string(b)] = struct{}{}
    31  	return len(b), nil
    32  }
    33  
    34  func (e exact) Count() float64 {
    35  	return float64(len(e))
    36  }
    37  
// counter is the behaviour the tests and benchmarks in this file need from a
// cardinality estimator: values are fed in through Write and the current
// cardinality is read back through Count.
type counter interface {
	io.Writer
	Count() float64
}
    42  
// counterTests is the table shared by TestCounters and TestUnion.
//
// For each entry, count is the number of distinct keys written to the counter
// returned by the counter constructor, and tol is the relative tolerance
// allowed between the reported and the true count (TestUnion doubles both).
// Entries whose name starts with "exact" are skipped by TestUnion.
var counterTests = []struct {
	name    string
	count   float64
	counter func() counter
	tol     float64
}{
	{name: "exact-1e5", count: 1e5, counter: func() counter { return make(exact) }, tol: 0},

	{name: "HyperLogLog32-0-10-FNV-1a", count: 0, counter: func() counter { return mustCounter(NewHyperLogLog32(10, fnv.New32a())) }, tol: 0},
	{name: "HyperLogLog64-0-10-FNV-1a", count: 0, counter: func() counter { return mustCounter(NewHyperLogLog64(10, fnv.New64a())) }, tol: 0},
	{name: "HyperLogLog32-10-14-FNV-1a", count: 10, counter: func() counter { return mustCounter(NewHyperLogLog32(14, fnv.New32a())) }, tol: 0.0005},
	{name: "HyperLogLog32-1e3-4-FNV-1a", count: 1e3, counter: func() counter { return mustCounter(NewHyperLogLog32(4, fnv.New32a())) }, tol: 0.1},
	{name: "HyperLogLog32-1e4-6-FNV-1a", count: 1e4, counter: func() counter { return mustCounter(NewHyperLogLog32(6, fnv.New32a())) }, tol: 0.06},
	{name: "HyperLogLog32-1e7-8-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog32(8, fnv.New32a())) }, tol: 0.03},
	{name: "HyperLogLog64-1e7-8-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog64(8, fnv.New64a())) }, tol: 0.07},
	{name: "HyperLogLog32-1e7-10-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog32(10, fnv.New32a())) }, tol: 0.06},
	{name: "HyperLogLog64-1e7-10-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog64(10, fnv.New64a())) }, tol: 0.02},
	{name: "HyperLogLog32-1e7-14-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog32(14, fnv.New32a())) }, tol: 0.005},
	{name: "HyperLogLog64-1e7-14-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog64(14, fnv.New64a())) }, tol: 0.002},
	{name: "HyperLogLog32-1e7-16-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog32(16, fnv.New32a())) }, tol: 0.005},
	{name: "HyperLogLog64-1e7-16-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog64(16, fnv.New64a())) }, tol: 0.002},
	{name: "HyperLogLog64-1e7-20-FNV-1a", count: 1e7, counter: func() counter { return mustCounter(NewHyperLogLog64(20, fnv.New64a())) }, tol: 0.001},
	{name: "HyperLogLog64-1e3-20-FNV-1a", count: 1e3, counter: func() counter { return mustCounter(NewHyperLogLog64(20, fnv.New64a())) }, tol: 0.001},
}
    67  
    68  func mustCounter(c counter, err error) counter {
    69  	if err != nil {
    70  		panic(fmt.Sprintf("bad test: %v", err))
    71  	}
    72  	return c
    73  }
    74  
    75  func TestCounters(t *testing.T) {
    76  	t.Parallel()
    77  
    78  	for _, test := range counterTests {
    79  		test := test
    80  		t.Run(test.name, func(t *testing.T) {
    81  			t.Parallel()
    82  
    83  			rnd := rand.New(rand.NewSource(1))
    84  			var dst []byte
    85  			c := test.counter()
    86  			for i := 0; i < int(test.count); i++ {
    87  				dst = strconv.AppendUint(dst[:0], rnd.Uint64(), 16)
    88  				dst = append(dst, '-')
    89  				dst = strconv.AppendUint(dst, uint64(i), 16)
    90  				n, err := c.Write(dst)
    91  				if n != len(dst) {
    92  					t.Errorf("unexpected number of bytes written for %s: got:%d want:%d",
    93  						test.name, n, len(dst))
    94  					break
    95  				}
    96  				if err != nil {
    97  					t.Errorf("unexpected error for %s: %v", test.name, err)
    98  					break
    99  				}
   100  			}
   101  
   102  			if got := c.Count(); !scalar.EqualWithinRel(got, test.count, test.tol) {
   103  				t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, got, test.count)
   104  			}
   105  		})
   106  	}
   107  }
   108  
   109  func TestUnion(t *testing.T) {
   110  	t.Parallel()
   111  
   112  	for _, test := range counterTests {
   113  		if strings.HasPrefix(test.name, "exact") {
   114  			continue
   115  		}
   116  		test := test
   117  		t.Run(test.name, func(t *testing.T) {
   118  			t.Parallel()
   119  			rnd := rand.New(rand.NewSource(1))
   120  			var dst []byte
   121  			var cs [2]counter
   122  			for j := range cs {
   123  				cs[j] = test.counter()
   124  				for i := 0; i < int(test.count); i++ {
   125  					dst = strconv.AppendUint(dst[:0], rnd.Uint64(), 16)
   126  					dst = append(dst, '-')
   127  					dst = strconv.AppendUint(dst, uint64(i), 16)
   128  					n, err := cs[j].Write(dst)
   129  					if n != len(dst) {
   130  						t.Errorf("unexpected number of bytes written for %s: got:%d want:%d",
   131  							test.name, n, len(dst))
   132  						break
   133  					}
   134  					if err != nil {
   135  						t.Errorf("unexpected error for %s: %v", test.name, err)
   136  						break
   137  					}
   138  				}
   139  			}
   140  
   141  			u := test.counter()
   142  			var err error
   143  			switch u := u.(type) {
   144  			case *HyperLogLog32:
   145  				err = u.Union(cs[0].(*HyperLogLog32), cs[1].(*HyperLogLog32))
   146  			case *HyperLogLog64:
   147  				err = u.Union(cs[0].(*HyperLogLog64), cs[1].(*HyperLogLog64))
   148  			}
   149  			if err != nil {
   150  				t.Errorf("unexpected error from Union call: %v", err)
   151  			}
   152  			if got := u.Count(); !scalar.EqualWithinRel(got, 2*test.count, 2*test.tol) {
   153  				t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, got, 2*test.count)
   154  			}
   155  		})
   156  	}
   157  }
   158  
// resetCounter is a counter that additionally has a Reset method.
// TestResetCounters checks that a counter refilled after Reset reports the
// same count as it did before.
type resetCounter interface {
	counter
	Reset()
}
   163  
// counterResetTests is the table for TestResetCounters. For each entry, count
// is the number of keys written per pass to the counter returned by
// resetCounter.
var counterResetTests = []struct {
	name         string
	count        int
	resetCounter func() resetCounter
}{
	{name: "HyperLogLog32-1e3-4-FNV-1a", count: 1e3, resetCounter: func() resetCounter { return mustResetCounter(NewHyperLogLog32(4, fnv.New32a())) }},
	{name: "HyperLogLog64-1e3-4-FNV-1a", count: 1e3, resetCounter: func() resetCounter { return mustResetCounter(NewHyperLogLog64(4, fnv.New64a())) }},
	{name: "HyperLogLog32-1e4-6-FNV-1a", count: 1e4, resetCounter: func() resetCounter { return mustResetCounter(NewHyperLogLog32(6, fnv.New32a())) }},
	{name: "HyperLogLog64-1e4-6-FNV-1a", count: 1e4, resetCounter: func() resetCounter { return mustResetCounter(NewHyperLogLog64(6, fnv.New64a())) }},
}
   174  
   175  func mustResetCounter(c resetCounter, err error) resetCounter {
   176  	if err != nil {
   177  		panic(fmt.Sprintf("bad test: %v", err))
   178  	}
   179  	return c
   180  }
   181  
   182  func TestResetCounters(t *testing.T) {
   183  	t.Parallel()
   184  
   185  	var dst []byte
   186  	for _, test := range counterResetTests {
   187  		c := test.resetCounter()
   188  		var counts [2]float64
   189  		for k := range counts {
   190  			rnd := rand.New(rand.NewSource(1))
   191  			for i := 0; i < test.count; i++ {
   192  				dst = strconv.AppendUint(dst[:0], rnd.Uint64(), 16)
   193  				dst = append(dst, '-')
   194  				dst = strconv.AppendUint(dst, uint64(i), 16)
   195  				n, err := c.Write(dst)
   196  				if n != len(dst) {
   197  					t.Errorf("unexpected number of bytes written for %s: got:%d want:%d",
   198  						test.name, n, len(dst))
   199  					break
   200  				}
   201  				if err != nil {
   202  					t.Errorf("unexpected error for %s: %v", test.name, err)
   203  					break
   204  				}
   205  			}
   206  			counts[k] = c.Count()
   207  			c.Reset()
   208  		}
   209  
   210  		if counts[0] != counts[1] {
   211  			t.Errorf("unexpected counts for %s after reset: got:%.0f", test.name, counts)
   212  		}
   213  	}
   214  }
   215  
// counterEncoder is a counter that can also be serialized to and restored
// from a binary form. TestBinaryEncoding uses it to round-trip counter state.
type counterEncoder interface {
	counter
	encoding.BinaryMarshaler
	encoding.BinaryUnmarshaler
}
   221  
// counterEncoderTests is the table for TestBinaryEncoding.
//
// For each entry, src constructs the counter that is filled with count keys
// and marshaled, dst constructs an already-initialized counter that the data
// is unmarshaled into, and zdst constructs a zero-value counter of the same
// type that the data is also unmarshaled into. The two numbers in each name
// are the first constructor arguments used for src and dst respectively.
var counterEncoderTests = []struct {
	name           string
	count          int
	src, dst, zdst func() counterEncoder
}{
	{
		name: "HyperLogLog32-4-4-FNV-1a", count: 1e3,
		src:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
		dst:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
		zdst: func() counterEncoder { return &HyperLogLog32{} },
	},
	{
		name: "HyperLogLog32-4-8-FNV-1a", count: 1e3,
		src:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
		dst:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
		zdst: func() counterEncoder { return &HyperLogLog32{} },
	},
	{
		name: "HyperLogLog32-8-4-FNV-1a", count: 1e3,
		src:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
		dst:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
		zdst: func() counterEncoder { return &HyperLogLog32{} },
	},
	{
		name: "HyperLogLog64-4-4-FNV-1a", count: 1e3,
		src:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
		dst:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
		zdst: func() counterEncoder { return &HyperLogLog64{} },
	},
	{
		name: "HyperLogLog64-4-8-FNV-1a", count: 1e3,
		src:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
		dst:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
		zdst: func() counterEncoder { return &HyperLogLog64{} },
	},
	{
		name: "HyperLogLog64-8-4-FNV-1a", count: 1e3,
		src:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
		dst:  func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
		zdst: func() counterEncoder { return &HyperLogLog64{} },
	},
}
   264  
   265  func mustCounterEncoder(c counterEncoder, err error) counterEncoder {
   266  	if err != nil {
   267  		panic(fmt.Sprintf("bad test: %v", err))
   268  	}
   269  	return c
   270  }
   271  
   272  func TestBinaryEncoding(t *testing.T) {
   273  	t.Parallel()
   274  
   275  	RegisterHash(fnv.New32a)
   276  	RegisterHash(fnv.New64a)
   277  	defer func() {
   278  		hashes = sync.Map{}
   279  	}()
   280  	for _, test := range counterEncoderTests {
   281  		rnd := rand.New(rand.NewSource(1))
   282  		src := test.src()
   283  		for i := 0; i < test.count; i++ {
   284  			buf := strconv.AppendUint(nil, rnd.Uint64(), 16)
   285  			buf = append(buf, '-')
   286  			buf = strconv.AppendUint(buf, uint64(i), 16)
   287  			n, err := src.Write(buf)
   288  			if n != len(buf) {
   289  				t.Errorf("unexpected number of bytes written for %s: got:%d want:%d",
   290  					test.name, n, len(buf))
   291  				break
   292  			}
   293  			if err != nil {
   294  				t.Errorf("unexpected error for %s: %v", test.name, err)
   295  				break
   296  			}
   297  		}
   298  
   299  		buf, err := src.MarshalBinary()
   300  		if err != nil {
   301  			t.Errorf("unexpected error marshaling binary for %s: %v", test.name, err)
   302  			continue
   303  		}
   304  		dst := test.dst()
   305  		err = dst.UnmarshalBinary(buf)
   306  		if err != nil {
   307  			t.Errorf("unexpected error unmarshaling binary for %s: %v", test.name, err)
   308  			continue
   309  		}
   310  		zdst := test.zdst()
   311  		err = zdst.UnmarshalBinary(buf)
   312  		if err != nil {
   313  			t.Errorf("unexpected error unmarshaling binary into zero receiver for %s: %v", test.name, err)
   314  			continue
   315  		}
   316  		gotSrc := src.Count()
   317  		gotDst := dst.Count()
   318  		gotZdst := zdst.Count()
   319  
   320  		if gotSrc != gotDst {
   321  			t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, gotDst, gotSrc)
   322  		}
   323  		if gotSrc != gotZdst {
   324  			t.Errorf("unexpected count for %s into zero receiver: got:%.0f want:%.0f", test.name, gotZdst, gotSrc)
   325  		}
   326  	}
   327  }
   328  
// invalidRegisterTests is the table for TestRegisterInvalid. Each fn is passed
// to RegisterHash and panics says whether that call is expected to panic.
// The table holds both rejected values and accepted ones (panics: false).
var invalidRegisterTests = []struct {
	fn     interface{}
	panics bool
}{
	{fn: int(0), panics: true},
	{fn: func() {}, panics: true},
	{fn: func(int) {}, panics: true},
	{fn: func() int { return 0 }, panics: true},
	{fn: func() hash.Hash { return fnv.New32a() }, panics: true},
	{fn: func() hash.Hash32 { return fnv.New32a() }, panics: false},
	{fn: func() hash.Hash { return fnv.New64a() }, panics: true},
	{fn: func() hash.Hash64 { return fnv.New64a() }, panics: false},
}
   342  
   343  func TestRegisterInvalid(t *testing.T) {
   344  	t.Parallel()
   345  
   346  	for _, test := range invalidRegisterTests {
   347  		var r interface{}
   348  		func() {
   349  			defer func() {
   350  				r = recover()
   351  			}()
   352  			RegisterHash(test.fn)
   353  		}()
   354  		panicked := r != nil
   355  		if panicked != test.panics {
   356  			if panicked {
   357  				t.Errorf("unexpected panic for %T", test.fn)
   358  			} else {
   359  				t.Errorf("expected panic for %T", test.fn)
   360  			}
   361  		}
   362  	}
   363  }
   364  
// rhoQTests is the table for TestRhoQ. bits is converted to uint32 and uint64
// and passed together with q to rho32q and rho64q; want is the result expected
// from both.
// NOTE(review): the values are consistent with want being one more than the
// number of leading zeros in the low q bits of bits; confirm against the
// rho32q/rho64q definitions.
var rhoQTests = []struct {
	bits uint
	q    uint8
	want uint8
}{
	{bits: 0xff, q: 8, want: 1},
	{bits: 0xfe, q: 8, want: 1},
	{bits: 0x0f, q: 8, want: 5},
	{bits: 0x1f, q: 8, want: 4},
	{bits: 0x00, q: 8, want: 9},
}
   376  
   377  func TestRhoQ(t *testing.T) {
   378  	t.Parallel()
   379  
   380  	for _, test := range rhoQTests {
   381  		got := rho32q(uint32(test.bits), test.q)
   382  		if got != test.want {
   383  			t.Errorf("unexpected rho32q for %0*b: got:%d want:%d", test.q, test.bits, got, test.want)
   384  		}
   385  		got = rho64q(uint64(test.bits), test.q)
   386  		if got != test.want {
   387  			t.Errorf("unexpected rho64q for %0*b: got:%d want:%d", test.q, test.bits, got, test.want)
   388  		}
   389  	}
   390  }
   391  
// counterBenchmarks is the table for BenchmarkCounters. For each entry, count
// is the number of keys written per benchmark iteration to the counter
// returned by the counter constructor.
var counterBenchmarks = []struct {
	name    string
	count   int
	counter func() counter
}{
	{name: "exact-1e6", count: 1e6, counter: func() counter { return make(exact) }},
	{name: "HyperLogLog32-1e6-8-FNV-1a", count: 1e6, counter: func() counter { return mustCounter(NewHyperLogLog32(8, fnv.New32a())) }},
	{name: "HyperLogLog64-1e6-8-FNV-1a", count: 1e6, counter: func() counter { return mustCounter(NewHyperLogLog64(8, fnv.New64a())) }},
	{name: "HyperLogLog32-1e6-16-FNV-1a", count: 1e6, counter: func() counter { return mustCounter(NewHyperLogLog32(16, fnv.New32a())) }},
	{name: "HyperLogLog64-1e6-16-FNV-1a", count: 1e6, counter: func() counter { return mustCounter(NewHyperLogLog64(16, fnv.New64a())) }},
}
   403  
   404  func BenchmarkCounters(b *testing.B) {
   405  	for _, bench := range counterBenchmarks {
   406  		c := bench.counter()
   407  		rnd := rand.New(rand.NewSource(1))
   408  		var dst []byte
   409  		b.Run(bench.name, func(b *testing.B) {
   410  			for i := 0; i < b.N; i++ {
   411  				for j := 0; j < bench.count; j++ {
   412  					dst = strconv.AppendUint(dst[:0], rnd.Uint64(), 16)
   413  					dst = append(dst, '-')
   414  					dst = strconv.AppendUint(dst, uint64(j), 16)
   415  					_, _ = c.Write(dst)
   416  				}
   417  			}
   418  			_ = c.Count()
   419  		})
   420  	}
   421  }