github.com/dolthub/swiss@v0.2.2-0.20240312182618-f4b2babd2bc1/map_test.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package swiss
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"math/rand"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  )
    27  
    28  func TestSwissMap(t *testing.T) {
    29  	t.Run("strings=0", func(t *testing.T) {
    30  		testSwissMap(t, genStringData(16, 0))
    31  	})
    32  	t.Run("strings=100", func(t *testing.T) {
    33  		testSwissMap(t, genStringData(16, 100))
    34  	})
    35  	t.Run("strings=1000", func(t *testing.T) {
    36  		testSwissMap(t, genStringData(16, 1000))
    37  	})
    38  	t.Run("strings=10_000", func(t *testing.T) {
    39  		testSwissMap(t, genStringData(16, 10_000))
    40  	})
    41  	t.Run("strings=100_000", func(t *testing.T) {
    42  		testSwissMap(t, genStringData(16, 100_000))
    43  	})
    44  	t.Run("uint32=0", func(t *testing.T) {
    45  		testSwissMap(t, genUint32Data(0))
    46  	})
    47  	t.Run("uint32=100", func(t *testing.T) {
    48  		testSwissMap(t, genUint32Data(100))
    49  	})
    50  	t.Run("uint32=1000", func(t *testing.T) {
    51  		testSwissMap(t, genUint32Data(1000))
    52  	})
    53  	t.Run("uint32=10_000", func(t *testing.T) {
    54  		testSwissMap(t, genUint32Data(10_000))
    55  	})
    56  	t.Run("uint32=100_000", func(t *testing.T) {
    57  		testSwissMap(t, genUint32Data(100_000))
    58  	})
    59  	t.Run("string capacity", func(t *testing.T) {
    60  		testSwissMapCapacity(t, func(n int) []string {
    61  			return genStringData(16, n)
    62  		})
    63  	})
    64  	t.Run("uint32 capacity", func(t *testing.T) {
    65  		testSwissMapCapacity(t, genUint32Data)
    66  	})
    67  }
    68  
    69  func testSwissMap[K comparable](t *testing.T, keys []K) {
    70  	// sanity check
    71  	require.Equal(t, len(keys), len(uniq(keys)), keys)
    72  	t.Run("put", func(t *testing.T) {
    73  		testMapPut(t, keys)
    74  	})
    75  	t.Run("has", func(t *testing.T) {
    76  		testMapHas(t, keys)
    77  	})
    78  	t.Run("get", func(t *testing.T) {
    79  		testMapGet(t, keys)
    80  	})
    81  	t.Run("delete", func(t *testing.T) {
    82  		testMapDelete(t, keys)
    83  	})
    84  	t.Run("clear", func(t *testing.T) {
    85  		testMapClear(t, keys)
    86  	})
    87  	t.Run("iter", func(t *testing.T) {
    88  		testMapIter(t, keys)
    89  	})
    90  	t.Run("grow", func(t *testing.T) {
    91  		testMapGrow(t, keys)
    92  	})
    93  	t.Run("probe stats", func(t *testing.T) {
    94  		testProbeStats(t, keys)
    95  	})
    96  }
    97  
    98  func uniq[K comparable](keys []K) []K {
    99  	s := make(map[K]struct{}, len(keys))
   100  	for _, k := range keys {
   101  		s[k] = struct{}{}
   102  	}
   103  	u := make([]K, 0, len(keys))
   104  	for k := range s {
   105  		u = append(u, k)
   106  	}
   107  	return u
   108  }
   109  
   110  func genStringData(size, count int) (keys []string) {
   111  	src := rand.New(rand.NewSource(int64(size * count)))
   112  	letters := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
   113  	r := make([]rune, size*count)
   114  	for i := range r {
   115  		r[i] = letters[src.Intn(len(letters))]
   116  	}
   117  	keys = make([]string, count)
   118  	for i := range keys {
   119  		keys[i] = string(r[:size])
   120  		r = r[size:]
   121  	}
   122  	return
   123  }
   124  
   125  func genUint32Data(count int) (keys []uint32) {
   126  	keys = make([]uint32, count)
   127  	var x uint32
   128  	for i := range keys {
   129  		x += (rand.Uint32() % 128) + 1
   130  		keys[i] = x
   131  	}
   132  	return
   133  }
   134  
   135  func testMapPut[K comparable](t *testing.T, keys []K) {
   136  	m := NewMap[K, int](uint32(len(keys)))
   137  	assert.Equal(t, 0, m.Count())
   138  	for i, key := range keys {
   139  		m.Put(key, i)
   140  	}
   141  	assert.Equal(t, len(keys), m.Count())
   142  	// overwrite
   143  	for i, key := range keys {
   144  		m.Put(key, -i)
   145  	}
   146  	assert.Equal(t, len(keys), m.Count())
   147  	for i, key := range keys {
   148  		act, ok := m.Get(key)
   149  		assert.True(t, ok)
   150  		assert.Equal(t, -i, act)
   151  	}
   152  	assert.Equal(t, len(keys), int(m.resident))
   153  }
   154  
   155  func testMapHas[K comparable](t *testing.T, keys []K) {
   156  	m := NewMap[K, int](uint32(len(keys)))
   157  	for i, key := range keys {
   158  		m.Put(key, i)
   159  	}
   160  	for _, key := range keys {
   161  		ok := m.Has(key)
   162  		assert.True(t, ok)
   163  	}
   164  }
   165  
   166  func testMapGet[K comparable](t *testing.T, keys []K) {
   167  	m := NewMap[K, int](uint32(len(keys)))
   168  	for i, key := range keys {
   169  		m.Put(key, i)
   170  	}
   171  	for i, key := range keys {
   172  		act, ok := m.Get(key)
   173  		assert.True(t, ok)
   174  		assert.Equal(t, i, act)
   175  	}
   176  }
   177  
   178  func testMapDelete[K comparable](t *testing.T, keys []K) {
   179  	m := NewMap[K, int](uint32(len(keys)))
   180  	assert.Equal(t, 0, m.Count())
   181  	for i, key := range keys {
   182  		m.Put(key, i)
   183  	}
   184  	assert.Equal(t, len(keys), m.Count())
   185  	for _, key := range keys {
   186  		m.Delete(key)
   187  		ok := m.Has(key)
   188  		assert.False(t, ok)
   189  	}
   190  	assert.Equal(t, 0, m.Count())
   191  	// put keys back after deleting them
   192  	for i, key := range keys {
   193  		m.Put(key, i)
   194  	}
   195  	assert.Equal(t, len(keys), m.Count())
   196  }
   197  
   198  func testMapClear[K comparable](t *testing.T, keys []K) {
   199  	m := NewMap[K, int](0)
   200  	assert.Equal(t, 0, m.Count())
   201  	for i, key := range keys {
   202  		m.Put(key, i)
   203  	}
   204  	assert.Equal(t, len(keys), m.Count())
   205  	m.Clear()
   206  	assert.Equal(t, 0, m.Count())
   207  	for _, key := range keys {
   208  		ok := m.Has(key)
   209  		assert.False(t, ok)
   210  		_, ok = m.Get(key)
   211  		assert.False(t, ok)
   212  	}
   213  	var calls int
   214  	m.Iter(func(k K, v int) (stop bool) {
   215  		calls++
   216  		return
   217  	})
   218  	assert.Equal(t, 0, calls)
   219  
   220  	// Assert that the map was actually cleared...
   221  	var k K
   222  	for _, g := range m.groups {
   223  		for i := range g.keys {
   224  			assert.Equal(t, k, g.keys[i])
   225  			assert.Equal(t, 0, g.values[i])
   226  		}
   227  	}
   228  }
   229  
   230  func testMapIter[K comparable](t *testing.T, keys []K) {
   231  	m := NewMap[K, int](uint32(len(keys)))
   232  	for i, key := range keys {
   233  		m.Put(key, i)
   234  	}
   235  	visited := make(map[K]uint, len(keys))
   236  	m.Iter(func(k K, v int) (stop bool) {
   237  		visited[k] = 0
   238  		stop = true
   239  		return
   240  	})
   241  	if len(keys) == 0 {
   242  		assert.Equal(t, len(visited), 0)
   243  	} else {
   244  		assert.Equal(t, len(visited), 1)
   245  	}
   246  	for _, k := range keys {
   247  		visited[k] = 0
   248  	}
   249  	m.Iter(func(k K, v int) (stop bool) {
   250  		visited[k]++
   251  		return
   252  	})
   253  	for _, c := range visited {
   254  		assert.Equal(t, c, uint(1))
   255  	}
   256  	// mutate on iter
   257  	m.Iter(func(k K, v int) (stop bool) {
   258  		m.Put(k, -v)
   259  		return
   260  	})
   261  	for i, key := range keys {
   262  		act, ok := m.Get(key)
   263  		assert.True(t, ok)
   264  		assert.Equal(t, -i, act)
   265  	}
   266  }
   267  
   268  func testMapGrow[K comparable](t *testing.T, keys []K) {
   269  	n := uint32(len(keys))
   270  	m := NewMap[K, int](n / 10)
   271  	for i, key := range keys {
   272  		m.Put(key, i)
   273  	}
   274  	for i, key := range keys {
   275  		act, ok := m.Get(key)
   276  		assert.True(t, ok)
   277  		assert.Equal(t, i, act)
   278  	}
   279  }
   280  
   281  func testSwissMapCapacity[K comparable](t *testing.T, gen func(n int) []K) {
   282  	// Capacity() behavior depends on |groupSize|
   283  	// which varies by processor architecture.
   284  	caps := []uint32{
   285  		1 * maxAvgGroupLoad,
   286  		2 * maxAvgGroupLoad,
   287  		3 * maxAvgGroupLoad,
   288  		4 * maxAvgGroupLoad,
   289  		5 * maxAvgGroupLoad,
   290  		10 * maxAvgGroupLoad,
   291  		25 * maxAvgGroupLoad,
   292  		50 * maxAvgGroupLoad,
   293  		100 * maxAvgGroupLoad,
   294  	}
   295  	for _, c := range caps {
   296  		m := NewMap[K, K](c)
   297  		assert.Equal(t, int(c), m.Capacity())
   298  		keys := gen(rand.Intn(int(c)))
   299  		for _, k := range keys {
   300  			m.Put(k, k)
   301  		}
   302  		assert.Equal(t, int(c)-len(keys), m.Capacity())
   303  		assert.Equal(t, int(c), m.Count()+m.Capacity())
   304  	}
   305  }
   306  
   307  func testProbeStats[K comparable](t *testing.T, keys []K) {
   308  	runTest := func(load float32) {
   309  		n := uint32(len(keys))
   310  		sz, k := loadFactorSample(n, load)
   311  		m := NewMap[K, int](sz)
   312  		for i, key := range keys[:k] {
   313  			m.Put(key, i)
   314  		}
   315  		// todo: assert stat invariants?
   316  		stats := getProbeStats(t, m, keys)
   317  		t.Log(fmtProbeStats(stats))
   318  	}
   319  	t.Run("load_factor=0.5", func(t *testing.T) {
   320  		runTest(0.5)
   321  	})
   322  	t.Run("load_factor=0.75", func(t *testing.T) {
   323  		runTest(0.75)
   324  	})
   325  	t.Run("load_factor=max", func(t *testing.T) {
   326  		runTest(maxLoadFactor)
   327  	})
   328  }
   329  
   330  // calculates the sample size and map size necessary to
   331  // create a load factor of |load| given |n| data points
   332  func loadFactorSample(n uint32, targetLoad float32) (mapSz, sampleSz uint32) {
   333  	if targetLoad > maxLoadFactor {
   334  		targetLoad = maxLoadFactor
   335  	}
   336  	// tables are assumed to be power of two
   337  	sampleSz = uint32(float32(n) * targetLoad)
   338  	mapSz = uint32(float32(n) * maxLoadFactor)
   339  	return
   340  }
   341  
   342  type probeStats struct {
   343  	groups     uint32
   344  	loadFactor float32
   345  	presentCnt uint32
   346  	presentMin uint32
   347  	presentMax uint32
   348  	presentAvg float32
   349  	absentCnt  uint32
   350  	absentMin  uint32
   351  	absentMax  uint32
   352  	absentAvg  float32
   353  }
   354  
   355  func fmtProbeStats(s probeStats) string {
   356  	g := fmt.Sprintf("groups=%d load=%f\n", s.groups, s.loadFactor)
   357  	p := fmt.Sprintf("present(n=%d): min=%d max=%d avg=%f\n",
   358  		s.presentCnt, s.presentMin, s.presentMax, s.presentAvg)
   359  	a := fmt.Sprintf("absent(n=%d):  min=%d max=%d avg=%f\n",
   360  		s.absentCnt, s.absentMin, s.absentMax, s.absentAvg)
   361  	return g + p + a
   362  }
   363  
   364  func getProbeLength[K comparable, V any](t *testing.T, m *Map[K, V], key K) (length uint32, ok bool) {
   365  	var end uint32
   366  	hi, lo := splitHash(m.hash.Hash(key))
   367  	start := probeStart(hi, len(m.groups))
   368  	end, _, ok = m.find(key, hi, lo)
   369  	if end < start { // wrapped
   370  		end += uint32(len(m.groups))
   371  	}
   372  	length = (end - start) + 1
   373  	require.True(t, length > 0)
   374  	return
   375  }
   376  
   377  func getProbeStats[K comparable, V any](t *testing.T, m *Map[K, V], keys []K) (stats probeStats) {
   378  	stats.groups = uint32(len(m.groups))
   379  	stats.loadFactor = m.loadFactor()
   380  	var presentSum, absentSum float32
   381  	stats.presentMin = math.MaxInt32
   382  	stats.absentMin = math.MaxInt32
   383  	for _, key := range keys {
   384  		l, ok := getProbeLength(t, m, key)
   385  		if ok {
   386  			stats.presentCnt++
   387  			presentSum += float32(l)
   388  			if stats.presentMin > l {
   389  				stats.presentMin = l
   390  			}
   391  			if stats.presentMax < l {
   392  				stats.presentMax = l
   393  			}
   394  		} else {
   395  			stats.absentCnt++
   396  			absentSum += float32(l)
   397  			if stats.absentMin > l {
   398  				stats.absentMin = l
   399  			}
   400  			if stats.absentMax < l {
   401  				stats.absentMax = l
   402  			}
   403  		}
   404  	}
   405  	if stats.presentCnt == 0 {
   406  		stats.presentMin = 0
   407  	} else {
   408  		stats.presentAvg = presentSum / float32(stats.presentCnt)
   409  	}
   410  	if stats.absentCnt == 0 {
   411  		stats.absentMin = 0
   412  	} else {
   413  		stats.absentAvg = absentSum / float32(stats.absentCnt)
   414  	}
   415  	return
   416  }
   417  
   418  func TestNumGroups(t *testing.T) {
   419  	assert.Equal(t, expected(0), numGroups(0))
   420  	assert.Equal(t, expected(1), numGroups(1))
   421  	// max load factor 0.875
   422  	assert.Equal(t, expected(14), numGroups(14))
   423  	assert.Equal(t, expected(15), numGroups(15))
   424  	assert.Equal(t, expected(28), numGroups(28))
   425  	assert.Equal(t, expected(29), numGroups(29))
   426  	assert.Equal(t, expected(56), numGroups(56))
   427  	assert.Equal(t, expected(57), numGroups(57))
   428  }
   429  
   430  func expected(x int) (groups uint32) {
   431  	groups = uint32(math.Ceil(float64(x) / float64(maxAvgGroupLoad)))
   432  	if groups == 0 {
   433  		groups = 1
   434  	}
   435  	return
   436  }