github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/persistent/map_test.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package persistent
     6  
     7  import (
     8  	"fmt"
     9  	"math/rand"
    10  	"reflect"
    11  	"sync/atomic"
    12  	"testing"
    13  )
    14  
    15  type mapEntry struct {
    16  	key   int
    17  	value int
    18  }
    19  
    20  type validatedMap struct {
    21  	impl     *Map
    22  	expected map[int]int      // current key-value mapping.
    23  	deleted  map[mapEntry]int // maps deleted entries to their clock time of last deletion
    24  	seen     map[mapEntry]int // maps seen entries to their clock time of last insertion
    25  	clock    int
    26  }
    27  
// TestSimpleMap drives a small Map through a fixed script of Set, Delete,
// Clone and Destroy operations. After each step validateRef checks the
// per-entry reference counts. At several checkpoints it compares the set
// of entries whose release callback has fired (deletedEntries) against the
// exact expected set.
func TestSimpleMap(t *testing.T) {
	// Shared by every map in this test so that releases observed through any
	// clone are recorded in one place.
	deletedEntries := make(map[mapEntry]int)
	seenEntries := make(map[mapEntry]int)

	m1 := &validatedMap{
		impl: NewMap(func(a, b interface{}) bool {
			return a.(int) < b.(int)
		}),
		expected: make(map[int]int),
		deleted:  deletedEntries,
		seen:     seenEntries,
	}

	// A clone of the empty m1 gains one entry and is destroyed; only that
	// entry may be released.
	m3 := m1.clone()
	validateRef(t, m1, m3)
	m3.set(t, 8, 8)
	validateRef(t, m1, m3)
	m3.destroy()

	assertSameMap(t, entrySet(deletedEntries), map[mapEntry]struct{}{
		{key: 8, value: 8}: {},
	})

	validateRef(t, m1)
	m1.set(t, 1, 1)
	validateRef(t, m1)
	m1.set(t, 2, 2)
	validateRef(t, m1)
	m1.set(t, 3, 3)
	validateRef(t, m1)
	m1.remove(t, 2)
	validateRef(t, m1)
	m1.set(t, 6, 6)
	validateRef(t, m1)

	// Removing key 2 from the only map holding it releases {2, 2}.
	assertSameMap(t, entrySet(deletedEntries), map[mapEntry]struct{}{
		{key: 2, value: 2}: {},
		{key: 8, value: 8}: {},
	})

	// From here on m2 shares structure with m1; mutations of m1 must not
	// disturb m2's reference counts.
	m2 := m1.clone()
	validateRef(t, m1, m2)
	m1.set(t, 6, 60)
	validateRef(t, m1, m2)
	m1.remove(t, 1)
	validateRef(t, m1, m2)

	// Deleting a key that was never set (100) and one already removed from
	// m1 (1) must not allocate.
	gotAllocs := int(testing.AllocsPerRun(10, func() {
		m1.impl.Delete(100)
		m1.impl.Delete(1)
	}))
	wantAllocs := 0
	if gotAllocs != wantAllocs {
		t.Errorf("wanted %d allocs, got %d", wantAllocs, gotAllocs)
	}

	for i := 10; i < 14; i++ {
		m1.set(t, i, i)
		validateRef(t, m1, m2)
	}

	m1.set(t, 10, 100)
	validateRef(t, m1, m2)

	m1.remove(t, 12)
	validateRef(t, m1, m2)

	m2.set(t, 4, 4)
	validateRef(t, m1, m2)
	m2.set(t, 5, 5)
	validateRef(t, m1, m2)

	// Destroying m1 releases the entries only m1 ever held. Entries still
	// reachable from m2 ({1,1}, {3,3}, {6,6}) must not appear here.
	m1.destroy()

	assertSameMap(t, entrySet(deletedEntries), map[mapEntry]struct{}{
		{key: 2, value: 2}:    {},
		{key: 6, value: 60}:   {},
		{key: 8, value: 8}:    {},
		{key: 10, value: 10}:  {},
		{key: 10, value: 100}: {},
		{key: 11, value: 11}:  {},
		{key: 12, value: 12}:  {},
		{key: 13, value: 13}:  {},
	})

	m2.set(t, 7, 7)
	validateRef(t, m2)

	m2.destroy()

	// With every map destroyed, each entry ever inserted must have been
	// released.
	assertSameMap(t, entrySet(seenEntries), entrySet(deletedEntries))
}
   120  
   121  func TestRandomMap(t *testing.T) {
   122  	deletedEntries := make(map[mapEntry]int)
   123  	seenEntries := make(map[mapEntry]int)
   124  
   125  	m := &validatedMap{
   126  		impl: NewMap(func(a, b interface{}) bool {
   127  			return a.(int) < b.(int)
   128  		}),
   129  		expected: make(map[int]int),
   130  		deleted:  deletedEntries,
   131  		seen:     seenEntries,
   132  	}
   133  
   134  	keys := make([]int, 0, 1000)
   135  	for i := 0; i < 1000; i++ {
   136  		key := rand.Intn(10000)
   137  		m.set(t, key, key)
   138  		keys = append(keys, key)
   139  
   140  		if i%10 == 1 {
   141  			index := rand.Intn(len(keys))
   142  			last := len(keys) - 1
   143  			key = keys[index]
   144  			keys[index], keys[last] = keys[last], keys[index]
   145  			keys = keys[:last]
   146  
   147  			m.remove(t, key)
   148  		}
   149  	}
   150  
   151  	m.destroy()
   152  	assertSameMap(t, entrySet(seenEntries), entrySet(deletedEntries))
   153  }
   154  
   155  func entrySet(m map[mapEntry]int) map[mapEntry]struct{} {
   156  	set := make(map[mapEntry]struct{})
   157  	for k := range m {
   158  		set[k] = struct{}{}
   159  	}
   160  	return set
   161  }
   162  
   163  func TestUpdate(t *testing.T) {
   164  	deletedEntries := make(map[mapEntry]int)
   165  	seenEntries := make(map[mapEntry]int)
   166  
   167  	m1 := &validatedMap{
   168  		impl: NewMap(func(a, b interface{}) bool {
   169  			return a.(int) < b.(int)
   170  		}),
   171  		expected: make(map[int]int),
   172  		deleted:  deletedEntries,
   173  		seen:     seenEntries,
   174  	}
   175  	m2 := m1.clone()
   176  
   177  	m1.set(t, 1, 1)
   178  	m1.set(t, 2, 2)
   179  	m2.set(t, 2, 20)
   180  	m2.set(t, 3, 3)
   181  	m1.setAll(t, m2)
   182  
   183  	m1.destroy()
   184  	m2.destroy()
   185  	assertSameMap(t, entrySet(seenEntries), entrySet(deletedEntries))
   186  }
   187  
   188  func validateRef(t *testing.T, maps ...*validatedMap) {
   189  	t.Helper()
   190  
   191  	actualCountByEntry := make(map[mapEntry]int32)
   192  	nodesByEntry := make(map[mapEntry]map[*mapNode]struct{})
   193  	expectedCountByEntry := make(map[mapEntry]int32)
   194  	for i, m := range maps {
   195  		dfsRef(m.impl.root, actualCountByEntry, nodesByEntry)
   196  		dumpMap(t, fmt.Sprintf("%d:", i), m.impl.root)
   197  	}
   198  	for entry, nodes := range nodesByEntry {
   199  		expectedCountByEntry[entry] = int32(len(nodes))
   200  	}
   201  	assertSameMap(t, expectedCountByEntry, actualCountByEntry)
   202  }
   203  
   204  func dfsRef(node *mapNode, countByEntry map[mapEntry]int32, nodesByEntry map[mapEntry]map[*mapNode]struct{}) {
   205  	if node == nil {
   206  		return
   207  	}
   208  
   209  	entry := mapEntry{key: node.key.(int), value: node.value.value.(int)}
   210  	countByEntry[entry] = atomic.LoadInt32(&node.value.refCount)
   211  
   212  	nodes, ok := nodesByEntry[entry]
   213  	if !ok {
   214  		nodes = make(map[*mapNode]struct{})
   215  		nodesByEntry[entry] = nodes
   216  	}
   217  	nodes[node] = struct{}{}
   218  
   219  	dfsRef(node.left, countByEntry, nodesByEntry)
   220  	dfsRef(node.right, countByEntry, nodesByEntry)
   221  }
   222  
   223  func dumpMap(t *testing.T, prefix string, n *mapNode) {
   224  	if n == nil {
   225  		t.Logf("%s nil", prefix)
   226  		return
   227  	}
   228  	t.Logf("%s {key: %v, value: %v (ref: %v), ref: %v, weight: %v}", prefix, n.key, n.value.value, n.value.refCount, n.refCount, n.weight)
   229  	dumpMap(t, prefix+"l", n.left)
   230  	dumpMap(t, prefix+"r", n.right)
   231  }
   232  
   233  func (vm *validatedMap) validate(t *testing.T) {
   234  	t.Helper()
   235  
   236  	validateNode(t, vm.impl.root, vm.impl.less)
   237  
   238  	// Note: this validation may not make sense if maps were constructed using
   239  	// SetAll operations. If this proves to be problematic, remove the clock,
   240  	// deleted, and seen fields.
   241  	for key, value := range vm.expected {
   242  		entry := mapEntry{key: key, value: value}
   243  		if deleteAt := vm.deleted[entry]; deleteAt > vm.seen[entry] {
   244  			t.Fatalf("entry is deleted prematurely, key: %d, value: %d", key, value)
   245  		}
   246  	}
   247  
   248  	actualMap := make(map[int]int, len(vm.expected))
   249  	vm.impl.Range(func(key, value interface{}) {
   250  		if other, ok := actualMap[key.(int)]; ok {
   251  			t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, value, other)
   252  		}
   253  		actualMap[key.(int)] = value.(int)
   254  	})
   255  
   256  	assertSameMap(t, actualMap, vm.expected)
   257  }
   258  
   259  func validateNode(t *testing.T, node *mapNode, less func(a, b interface{}) bool) {
   260  	if node == nil {
   261  		return
   262  	}
   263  
   264  	if node.left != nil {
   265  		if less(node.key, node.left.key) {
   266  			t.Fatalf("left child has larger key: %v vs %v", node.left.key, node.key)
   267  		}
   268  		if node.left.weight > node.weight {
   269  			t.Fatalf("left child has larger weight: %v vs %v", node.left.weight, node.weight)
   270  		}
   271  	}
   272  
   273  	if node.right != nil {
   274  		if less(node.right.key, node.key) {
   275  			t.Fatalf("right child has smaller key: %v vs %v", node.right.key, node.key)
   276  		}
   277  		if node.right.weight > node.weight {
   278  			t.Fatalf("right child has larger weight: %v vs %v", node.right.weight, node.weight)
   279  		}
   280  	}
   281  
   282  	validateNode(t, node.left, less)
   283  	validateNode(t, node.right, less)
   284  }
   285  
   286  func (vm *validatedMap) setAll(t *testing.T, other *validatedMap) {
   287  	vm.impl.SetAll(other.impl)
   288  
   289  	// Note: this is buggy because we are not updating vm.clock, vm.deleted, or
   290  	// vm.seen.
   291  	for key, value := range other.expected {
   292  		vm.expected[key] = value
   293  	}
   294  	vm.validate(t)
   295  }
   296  
   297  func (vm *validatedMap) set(t *testing.T, key, value int) {
   298  	entry := mapEntry{key: key, value: value}
   299  
   300  	vm.clock++
   301  	vm.seen[entry] = vm.clock
   302  
   303  	vm.impl.Set(key, value, func(deletedKey, deletedValue interface{}) {
   304  		if deletedKey != key || deletedValue != value {
   305  			t.Fatalf("unexpected passed in deleted entry: %v/%v, expected: %v/%v", deletedKey, deletedValue, key, value)
   306  		}
   307  		// Not safe if closure shared between two validatedMaps.
   308  		vm.deleted[entry] = vm.clock
   309  	})
   310  	vm.expected[key] = value
   311  	vm.validate(t)
   312  
   313  	gotValue, ok := vm.impl.Get(key)
   314  	if !ok || gotValue != value {
   315  		t.Fatalf("unexpected get result after insertion, key: %v, expected: %v, got: %v (%v)", key, value, gotValue, ok)
   316  	}
   317  }
   318  
   319  func (vm *validatedMap) remove(t *testing.T, key int) {
   320  	vm.clock++
   321  	vm.impl.Delete(key)
   322  	delete(vm.expected, key)
   323  	vm.validate(t)
   324  
   325  	gotValue, ok := vm.impl.Get(key)
   326  	if ok {
   327  		t.Fatalf("unexpected get result after removal, key: %v, got: %v", key, gotValue)
   328  	}
   329  }
   330  
   331  func (vm *validatedMap) clone() *validatedMap {
   332  	expected := make(map[int]int, len(vm.expected))
   333  	for key, value := range vm.expected {
   334  		expected[key] = value
   335  	}
   336  
   337  	return &validatedMap{
   338  		impl:     vm.impl.Clone(),
   339  		expected: expected,
   340  		deleted:  vm.deleted,
   341  		seen:     vm.seen,
   342  	}
   343  }
   344  
   345  func (vm *validatedMap) destroy() {
   346  	vm.impl.Destroy()
   347  }
   348  
   349  func assertSameMap(t *testing.T, map1, map2 interface{}) {
   350  	t.Helper()
   351  
   352  	if !reflect.DeepEqual(map1, map2) {
   353  		t.Fatalf("different maps:\n%v\nvs\n%v", map1, map2)
   354  	}
   355  }