github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/callgraph/vta/internal/trie/op_test.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package trie_test
     6  
import (
	"fmt"
	"math/rand"
	"reflect"
	"sort"
	"testing"
	"time"

	"golang.org/x/tools/go/callgraph/vta/internal/trie"
)
    16  
// This file tests trie.Map by cross-checking operations on a collection of
// trie.Maps against a collection of map[uint64]interface{}. This includes
// both limited fuzz testing for correctness and benchmarking.
    20  
    21  // mapCollection is effectively a []map[uint64]interface{}.
    22  // These support operations being applied to the i'th maps.
    23  type mapCollection interface {
    24  	Elements() []map[uint64]interface{}
    25  
    26  	DeepEqual(l, r int) bool
    27  	Lookup(id int, k uint64) (interface{}, bool)
    28  
    29  	Insert(id int, k uint64, v interface{})
    30  	Update(id int, k uint64, v interface{})
    31  	Remove(id int, k uint64)
    32  	Intersect(l int, r int)
    33  	Merge(l int, r int)
    34  	Clear(id int)
    35  
    36  	Average(l int, r int)
    37  	Assign(l int, r int)
    38  }
    39  
    40  // opCode of an operation.
    41  type opCode int
    42  
    43  const (
    44  	deepEqualsOp opCode = iota
    45  	lookupOp
    46  	insert
    47  	update
    48  	remove
    49  	merge
    50  	intersect
    51  	clear
    52  	takeAverage
    53  	assign
    54  )
    55  
    56  func (op opCode) String() string {
    57  	switch op {
    58  	case deepEqualsOp:
    59  		return "DE"
    60  	case lookupOp:
    61  		return "LO"
    62  	case insert:
    63  		return "IN"
    64  	case update:
    65  		return "UP"
    66  	case remove:
    67  		return "RE"
    68  	case merge:
    69  		return "ME"
    70  	case intersect:
    71  		return "IT"
    72  	case clear:
    73  		return "CL"
    74  	case takeAverage:
    75  		return "AV"
    76  	case assign:
    77  		return "AS"
    78  	default:
    79  		return "??"
    80  	}
    81  }
    82  
    83  // A mapCollection backed by MutMaps.
    84  type trieCollection struct {
    85  	b     *trie.Builder
    86  	tries []trie.MutMap
    87  }
    88  
    89  func (c *trieCollection) Elements() []map[uint64]interface{} {
    90  	var maps []map[uint64]interface{}
    91  	for _, m := range c.tries {
    92  		maps = append(maps, trie.Elems(m.M))
    93  	}
    94  	return maps
    95  }
    96  func (c *trieCollection) Eq(id int, m map[uint64]interface{}) bool {
    97  	elems := trie.Elems(c.tries[id].M)
    98  	return !reflect.DeepEqual(elems, m)
    99  }
   100  
   101  func (c *trieCollection) Lookup(id int, k uint64) (interface{}, bool) {
   102  	return c.tries[id].M.Lookup(k)
   103  }
   104  func (c *trieCollection) DeepEqual(l, r int) bool {
   105  	return c.tries[l].M.DeepEqual(c.tries[r].M)
   106  }
   107  
   108  func (c *trieCollection) Add() {
   109  	c.tries = append(c.tries, c.b.MutEmpty())
   110  }
   111  
   112  func (c *trieCollection) Insert(id int, k uint64, v interface{}) {
   113  	c.tries[id].Insert(k, v)
   114  }
   115  
   116  func (c *trieCollection) Update(id int, k uint64, v interface{}) {
   117  	c.tries[id].Update(k, v)
   118  }
   119  
   120  func (c *trieCollection) Remove(id int, k uint64) {
   121  	c.tries[id].Remove(k)
   122  }
   123  
   124  func (c *trieCollection) Intersect(l int, r int) {
   125  	c.tries[l].Intersect(c.tries[r].M)
   126  }
   127  
   128  func (c *trieCollection) Merge(l int, r int) {
   129  	c.tries[l].Merge(c.tries[r].M)
   130  }
   131  
   132  func (c *trieCollection) Average(l int, r int) {
   133  	c.tries[l].MergeWith(average, c.tries[r].M)
   134  }
   135  
   136  func (c *trieCollection) Clear(id int) {
   137  	c.tries[id] = c.b.MutEmpty()
   138  }
   139  func (c *trieCollection) Assign(l, r int) {
   140  	c.tries[l] = c.tries[r]
   141  }
   142  
   143  func average(x interface{}, y interface{}) interface{} {
   144  	if x, ok := x.(float32); ok {
   145  		if y, ok := y.(float32); ok {
   146  			return (x + y) / 2.0
   147  		}
   148  	}
   149  	return x
   150  }
   151  
   152  type builtinCollection []map[uint64]interface{}
   153  
   154  func (c builtinCollection) Elements() []map[uint64]interface{} {
   155  	return c
   156  }
   157  
   158  func (c builtinCollection) Lookup(id int, k uint64) (interface{}, bool) {
   159  	v, ok := c[id][k]
   160  	return v, ok
   161  }
   162  func (c builtinCollection) DeepEqual(l, r int) bool {
   163  	return reflect.DeepEqual(c[l], c[r])
   164  }
   165  
   166  func (c builtinCollection) Insert(id int, k uint64, v interface{}) {
   167  	if _, ok := c[id][k]; !ok {
   168  		c[id][k] = v
   169  	}
   170  }
   171  
   172  func (c builtinCollection) Update(id int, k uint64, v interface{}) {
   173  	c[id][k] = v
   174  }
   175  
   176  func (c builtinCollection) Remove(id int, k uint64) {
   177  	delete(c[id], k)
   178  }
   179  
   180  func (c builtinCollection) Intersect(l int, r int) {
   181  	result := map[uint64]interface{}{}
   182  	for k, v := range c[l] {
   183  		if _, ok := c[r][k]; ok {
   184  			result[k] = v
   185  		}
   186  	}
   187  	c[l] = result
   188  }
   189  
   190  func (c builtinCollection) Merge(l int, r int) {
   191  	result := map[uint64]interface{}{}
   192  	for k, v := range c[r] {
   193  		result[k] = v
   194  	}
   195  	for k, v := range c[l] {
   196  		result[k] = v
   197  	}
   198  	c[l] = result
   199  }
   200  
   201  func (c builtinCollection) Average(l int, r int) {
   202  	avg := map[uint64]interface{}{}
   203  	for k, lv := range c[l] {
   204  		if rv, ok := c[r][k]; ok {
   205  			avg[k] = average(lv, rv)
   206  		} else {
   207  			avg[k] = lv // add elements just in l
   208  		}
   209  	}
   210  	for k, rv := range c[r] {
   211  		if _, ok := c[l][k]; !ok {
   212  			avg[k] = rv // add elements just in r
   213  		}
   214  	}
   215  	c[l] = avg
   216  }
   217  
   218  func (c builtinCollection) Assign(l, r int) {
   219  	m := map[uint64]interface{}{}
   220  	for k, v := range c[r] {
   221  		m[k] = v
   222  	}
   223  	c[l] = m
   224  }
   225  
   226  func (c builtinCollection) Clear(id int) {
   227  	c[id] = map[uint64]interface{}{}
   228  }
   229  
   230  func newTriesCollection(size int) *trieCollection {
   231  	tc := &trieCollection{
   232  		b:     trie.NewBuilder(),
   233  		tries: make([]trie.MutMap, size),
   234  	}
   235  	for i := 0; i < size; i++ {
   236  		tc.tries[i] = tc.b.MutEmpty()
   237  	}
   238  	return tc
   239  }
   240  
   241  func newMapsCollection(size int) *builtinCollection {
   242  	maps := make(builtinCollection, size)
   243  	for i := 0; i < size; i++ {
   244  		maps[i] = map[uint64]interface{}{}
   245  	}
   246  	return &maps
   247  }
   248  
   249  // operation on a map collection.
   250  type operation struct {
   251  	code opCode
   252  	l, r int
   253  	k    uint64
   254  	v    float32
   255  }
   256  
   257  // Apply the operation to maps.
   258  func (op operation) Apply(maps mapCollection) interface{} {
   259  	type lookupresult struct {
   260  		v  interface{}
   261  		ok bool
   262  	}
   263  	switch op.code {
   264  	case deepEqualsOp:
   265  		return maps.DeepEqual(op.l, op.r)
   266  	case lookupOp:
   267  		v, ok := maps.Lookup(op.l, op.k)
   268  		return lookupresult{v, ok}
   269  	case insert:
   270  		maps.Insert(op.l, op.k, op.v)
   271  	case update:
   272  		maps.Update(op.l, op.k, op.v)
   273  	case remove:
   274  		maps.Remove(op.l, op.k)
   275  	case merge:
   276  		maps.Merge(op.l, op.r)
   277  	case intersect:
   278  		maps.Intersect(op.l, op.r)
   279  	case clear:
   280  		maps.Clear(op.l)
   281  	case takeAverage:
   282  		maps.Average(op.l, op.r)
   283  	case assign:
   284  		maps.Assign(op.l, op.r)
   285  	}
   286  	return nil
   287  }
   288  
   289  // Returns a collection of op codes with dist[op] copies of op.
   290  func distribution(dist map[opCode]int) []opCode {
   291  	var codes []opCode
   292  	for op, n := range dist {
   293  		for i := 0; i < n; i++ {
   294  			codes = append(codes, op)
   295  		}
   296  	}
   297  	return codes
   298  }
   299  
   300  // options for generating a random operation.
   301  type options struct {
   302  	maps   int
   303  	maxKey uint64
   304  	maxVal int
   305  	codes  []opCode
   306  }
   307  
   308  // returns a random operation using r as a source of randomness.
   309  func randOperator(r *rand.Rand, opts options) operation {
   310  	id := func() int { return r.Intn(opts.maps) }
   311  	key := func() uint64 { return r.Uint64() % opts.maxKey }
   312  	val := func() float32 { return float32(r.Intn(opts.maxVal)) }
   313  	switch code := opts.codes[r.Intn(len(opts.codes))]; code {
   314  	case lookupOp, remove:
   315  		return operation{code: code, l: id(), k: key()}
   316  	case insert, update:
   317  		return operation{code: code, l: id(), k: key(), v: val()}
   318  	case deepEqualsOp, merge, intersect, takeAverage, assign:
   319  		return operation{code: code, l: id(), r: id()}
   320  	case clear:
   321  		return operation{code: code, l: id()}
   322  	default:
   323  		panic("Invalid op code")
   324  	}
   325  }
   326  
   327  func randOperators(r *rand.Rand, numops int, opts options) []operation {
   328  	ops := make([]operation, numops)
   329  	for i := 0; i < numops; i++ {
   330  		ops[i] = randOperator(r, opts)
   331  	}
   332  	return ops
   333  }
   334  
   335  // TestOperations applies a series of random operations to collection of
   336  // trie.MutMaps and map[uint64]interface{}. It tests for the maps being equal.
   337  func TestOperations(t *testing.T) {
   338  	seed := time.Now().UnixNano()
   339  	s := rand.NewSource(seed)
   340  	r := rand.New(s)
   341  	t.Log("seed: ", seed)
   342  
   343  	size := 10
   344  	N := 100000
   345  	ops := randOperators(r, N, options{
   346  		maps:   size,
   347  		maxKey: 128,
   348  		maxVal: 100,
   349  		codes: distribution(map[opCode]int{
   350  			deepEqualsOp: 1,
   351  			lookupOp:     10,
   352  			insert:       10,
   353  			update:       10,
   354  			remove:       10,
   355  			merge:        10,
   356  			intersect:    10,
   357  			clear:        2,
   358  			takeAverage:  5,
   359  			assign:       5,
   360  		}),
   361  	})
   362  
   363  	var tries mapCollection = newTriesCollection(size)
   364  	var maps mapCollection = newMapsCollection(size)
   365  	check := func() error {
   366  		if got, want := tries.Elements(), maps.Elements(); !reflect.DeepEqual(got, want) {
   367  			return fmt.Errorf("elements of tries and maps and tries differed. got %v want %v", got, want)
   368  		}
   369  		return nil
   370  	}
   371  
   372  	for i, op := range ops {
   373  		got, want := op.Apply(tries), op.Apply(maps)
   374  		if got != want {
   375  			t.Errorf("op[%d]: (%v).Apply(%v) != (%v).Apply(%v). got %v want %v",
   376  				i, op, tries, op, maps, got, want)
   377  		}
   378  	}
   379  	if err := check(); err != nil {
   380  		t.Errorf("%d operators failed with %s", size, err)
   381  		t.Log("Rerunning with more checking")
   382  		tries, maps = newTriesCollection(size), newMapsCollection(size)
   383  		for i, op := range ops {
   384  			op.Apply(tries)
   385  			op.Apply(maps)
   386  			if err := check(); err != nil {
   387  				t.Fatalf("Failed first on op[%d]=%v: %v", i, op, err)
   388  			}
   389  		}
   390  	}
   391  }
   392  
   393  func run(b *testing.B, opts options, seed int64, mk func(int) mapCollection) {
   394  	r := rand.New(rand.NewSource(seed))
   395  	ops := randOperators(r, b.N, opts)
   396  	maps := mk(opts.maps)
   397  	for _, op := range ops {
   398  		op.Apply(maps)
   399  	}
   400  }
   401  
   402  var standard options = options{
   403  	maps:   10,
   404  	maxKey: 128,
   405  	maxVal: 100,
   406  	codes: distribution(map[opCode]int{
   407  		deepEqualsOp: 1,
   408  		lookupOp:     20,
   409  		insert:       20,
   410  		update:       20,
   411  		remove:       20,
   412  		merge:        10,
   413  		intersect:    10,
   414  		clear:        1,
   415  		takeAverage:  5,
   416  		assign:       20,
   417  	}),
   418  }
   419  
   420  func BenchmarkTrieStandard(b *testing.B) {
   421  	run(b, standard, 123, func(size int) mapCollection {
   422  		return newTriesCollection(size)
   423  	})
   424  }
   425  
   426  func BenchmarkMapsStandard(b *testing.B) {
   427  	run(b, standard, 123, func(size int) mapCollection {
   428  		return newMapsCollection(size)
   429  	})
   430  }
   431  
   432  var smallWide options = options{
   433  	maps:   100,
   434  	maxKey: 100,
   435  	maxVal: 8,
   436  	codes: distribution(map[opCode]int{
   437  		deepEqualsOp: 0,
   438  		lookupOp:     0,
   439  		insert:       30,
   440  		update:       20,
   441  		remove:       0,
   442  		merge:        10,
   443  		intersect:    0,
   444  		clear:        1,
   445  		takeAverage:  0,
   446  		assign:       30,
   447  	}),
   448  }
   449  
   450  func BenchmarkTrieSmallWide(b *testing.B) {
   451  	run(b, smallWide, 456, func(size int) mapCollection {
   452  		return newTriesCollection(size)
   453  	})
   454  }
   455  
   456  func BenchmarkMapsSmallWide(b *testing.B) {
   457  	run(b, smallWide, 456, func(size int) mapCollection {
   458  		return newMapsCollection(size)
   459  	})
   460  }