github.com/dolthub/swiss@v0.2.2-0.20240312182618-f4b2babd2bc1/map_test.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package swiss 16 17 import ( 18 "fmt" 19 "math" 20 "math/rand" 21 "testing" 22 23 "github.com/stretchr/testify/require" 24 25 "github.com/stretchr/testify/assert" 26 ) 27 28 func TestSwissMap(t *testing.T) { 29 t.Run("strings=0", func(t *testing.T) { 30 testSwissMap(t, genStringData(16, 0)) 31 }) 32 t.Run("strings=100", func(t *testing.T) { 33 testSwissMap(t, genStringData(16, 100)) 34 }) 35 t.Run("strings=1000", func(t *testing.T) { 36 testSwissMap(t, genStringData(16, 1000)) 37 }) 38 t.Run("strings=10_000", func(t *testing.T) { 39 testSwissMap(t, genStringData(16, 10_000)) 40 }) 41 t.Run("strings=100_000", func(t *testing.T) { 42 testSwissMap(t, genStringData(16, 100_000)) 43 }) 44 t.Run("uint32=0", func(t *testing.T) { 45 testSwissMap(t, genUint32Data(0)) 46 }) 47 t.Run("uint32=100", func(t *testing.T) { 48 testSwissMap(t, genUint32Data(100)) 49 }) 50 t.Run("uint32=1000", func(t *testing.T) { 51 testSwissMap(t, genUint32Data(1000)) 52 }) 53 t.Run("uint32=10_000", func(t *testing.T) { 54 testSwissMap(t, genUint32Data(10_000)) 55 }) 56 t.Run("uint32=100_000", func(t *testing.T) { 57 testSwissMap(t, genUint32Data(100_000)) 58 }) 59 t.Run("string capacity", func(t *testing.T) { 60 testSwissMapCapacity(t, func(n int) []string { 61 return genStringData(16, n) 62 }) 63 }) 64 t.Run("uint32 capacity", func(t *testing.T) { 65 testSwissMapCapacity(t, genUint32Data) 66 }) 67 } 68 69 func testSwissMap[K comparable](t *testing.T, keys []K) { 70 // sanity check 71 require.Equal(t, len(keys), len(uniq(keys)), keys) 72 t.Run("put", func(t *testing.T) { 73 testMapPut(t, keys) 74 }) 75 t.Run("has", func(t *testing.T) { 76 testMapHas(t, keys) 77 }) 78 t.Run("get", func(t *testing.T) { 79 testMapGet(t, keys) 80 }) 81 t.Run("delete", func(t *testing.T) { 82 testMapDelete(t, keys) 83 }) 84 t.Run("clear", func(t *testing.T) { 85 testMapClear(t, keys) 86 }) 87 t.Run("iter", func(t *testing.T) { 88 testMapIter(t, keys) 89 }) 90 t.Run("grow", func(t *testing.T) { 91 testMapGrow(t, keys) 92 }) 93 t.Run("probe stats", func(t *testing.T) { 94 testProbeStats(t, keys) 95 }) 96 } 97 98 func uniq[K comparable](keys []K) []K { 99 s := make(map[K]struct{}, len(keys)) 100 for _, k := range keys { 101 s[k] = struct{}{} 102 } 103 u := make([]K, 0, len(keys)) 104 for k := range s { 105 u = append(u, k) 106 } 107 return u 108 } 109 110 func genStringData(size, count int) (keys []string) { 111 src := rand.New(rand.NewSource(int64(size * count))) 112 letters := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 113 r := make([]rune, size*count) 114 for i := range r { 115 r[i] = letters[src.Intn(len(letters))] 116 } 117 keys = make([]string, count) 118 for i := range keys { 119 keys[i] = string(r[:size]) 120 r = r[size:] 121 } 122 return 123 } 124 125 func genUint32Data(count int) (keys []uint32) { 126 keys = make([]uint32, count) 127 var x uint32 128 for i := range keys { 129 x += (rand.Uint32() % 128) + 1 130 keys[i] = x 131 } 132 return 133 } 134 135 func testMapPut[K comparable](t *testing.T, keys []K) { 136 m := NewMap[K, int](uint32(len(keys))) 137 assert.Equal(t, 0, m.Count()) 138 for i, key := range keys { 139 m.Put(key, i) 140 } 141 assert.Equal(t, len(keys), m.Count()) 142 // overwrite 143 for i, key := range keys { 144 m.Put(key, -i) 145 } 146 assert.Equal(t, len(keys), m.Count()) 147 for i, key := range keys { 148 act, ok := m.Get(key) 149 assert.True(t, ok) 150 assert.Equal(t, -i, act) 151 } 152 assert.Equal(t, len(keys), int(m.resident)) 153 } 154 155 func testMapHas[K comparable](t *testing.T, keys []K) { 156 m := NewMap[K, int](uint32(len(keys))) 157 for i, key := range keys { 158 m.Put(key, i) 159 } 160 for _, key := range keys { 161 ok := m.Has(key) 162 assert.True(t, ok) 163 } 164 } 165 166 func testMapGet[K comparable](t *testing.T, keys []K) { 167 m := NewMap[K, int](uint32(len(keys))) 168 for i, key := range keys { 169 m.Put(key, i) 170 } 171 for i, key := range keys { 172 act, ok := m.Get(key) 173 assert.True(t, ok) 174 assert.Equal(t, i, act) 175 } 176 } 177 178 func testMapDelete[K comparable](t *testing.T, keys []K) { 179 m := NewMap[K, int](uint32(len(keys))) 180 assert.Equal(t, 0, m.Count()) 181 for i, key := range keys { 182 m.Put(key, i) 183 } 184 assert.Equal(t, len(keys), m.Count()) 185 for _, key := range keys { 186 m.Delete(key) 187 ok := m.Has(key) 188 assert.False(t, ok) 189 } 190 assert.Equal(t, 0, m.Count()) 191 // put keys back after deleting them 192 for i, key := range keys { 193 m.Put(key, i) 194 } 195 assert.Equal(t, len(keys), m.Count()) 196 } 197 198 func testMapClear[K comparable](t *testing.T, keys []K) { 199 m := NewMap[K, int](0) 200 assert.Equal(t, 0, m.Count()) 201 for i, key := range keys { 202 m.Put(key, i) 203 } 204 assert.Equal(t, len(keys), m.Count()) 205 m.Clear() 206 assert.Equal(t, 0, m.Count()) 207 for _, key := range keys { 208 ok := m.Has(key) 209 assert.False(t, ok) 210 _, ok = m.Get(key) 211 assert.False(t, ok) 212 } 213 var calls int 214 m.Iter(func(k K, v int) (stop bool) { 215 calls++ 216 return 217 }) 218 assert.Equal(t, 0, calls) 219 220 // Assert that the map was actually cleared... 221 var k K 222 for _, g := range m.groups { 223 for i := range g.keys { 224 assert.Equal(t, k, g.keys[i]) 225 assert.Equal(t, 0, g.values[i]) 226 } 227 } 228 } 229 230 func testMapIter[K comparable](t *testing.T, keys []K) { 231 m := NewMap[K, int](uint32(len(keys))) 232 for i, key := range keys { 233 m.Put(key, i) 234 } 235 visited := make(map[K]uint, len(keys)) 236 m.Iter(func(k K, v int) (stop bool) { 237 visited[k] = 0 238 stop = true 239 return 240 }) 241 if len(keys) == 0 { 242 assert.Equal(t, len(visited), 0) 243 } else { 244 assert.Equal(t, len(visited), 1) 245 } 246 for _, k := range keys { 247 visited[k] = 0 248 } 249 m.Iter(func(k K, v int) (stop bool) { 250 visited[k]++ 251 return 252 }) 253 for _, c := range visited { 254 assert.Equal(t, c, uint(1)) 255 } 256 // mutate on iter 257 m.Iter(func(k K, v int) (stop bool) { 258 m.Put(k, -v) 259 return 260 }) 261 for i, key := range keys { 262 act, ok := m.Get(key) 263 assert.True(t, ok) 264 assert.Equal(t, -i, act) 265 } 266 } 267 268 func testMapGrow[K comparable](t *testing.T, keys []K) { 269 n := uint32(len(keys)) 270 m := NewMap[K, int](n / 10) 271 for i, key := range keys { 272 m.Put(key, i) 273 } 274 for i, key := range keys { 275 act, ok := m.Get(key) 276 assert.True(t, ok) 277 assert.Equal(t, i, act) 278 } 279 } 280 281 func testSwissMapCapacity[K comparable](t *testing.T, gen func(n int) []K) { 282 // Capacity() behavior depends on |groupSize| 283 // which varies by processor architecture. 284 caps := []uint32{ 285 1 * maxAvgGroupLoad, 286 2 * maxAvgGroupLoad, 287 3 * maxAvgGroupLoad, 288 4 * maxAvgGroupLoad, 289 5 * maxAvgGroupLoad, 290 10 * maxAvgGroupLoad, 291 25 * maxAvgGroupLoad, 292 50 * maxAvgGroupLoad, 293 100 * maxAvgGroupLoad, 294 } 295 for _, c := range caps { 296 m := NewMap[K, K](c) 297 assert.Equal(t, int(c), m.Capacity()) 298 keys := gen(rand.Intn(int(c))) 299 for _, k := range keys { 300 m.Put(k, k) 301 } 302 assert.Equal(t, int(c)-len(keys), m.Capacity()) 303 assert.Equal(t, int(c), m.Count()+m.Capacity()) 304 } 305 } 306 307 func testProbeStats[K comparable](t *testing.T, keys []K) { 308 runTest := func(load float32) { 309 n := uint32(len(keys)) 310 sz, k := loadFactorSample(n, load) 311 m := NewMap[K, int](sz) 312 for i, key := range keys[:k] { 313 m.Put(key, i) 314 } 315 // todo: assert stat invariants? 316 stats := getProbeStats(t, m, keys) 317 t.Log(fmtProbeStats(stats)) 318 } 319 t.Run("load_factor=0.5", func(t *testing.T) { 320 runTest(0.5) 321 }) 322 t.Run("load_factor=0.75", func(t *testing.T) { 323 runTest(0.75) 324 }) 325 t.Run("load_factor=max", func(t *testing.T) { 326 runTest(maxLoadFactor) 327 }) 328 } 329 330 // calculates the sample size and map size necessary to 331 // create a load factor of |load| given |n| data points 332 func loadFactorSample(n uint32, targetLoad float32) (mapSz, sampleSz uint32) { 333 if targetLoad > maxLoadFactor { 334 targetLoad = maxLoadFactor 335 } 336 // tables are assumed to be power of two 337 sampleSz = uint32(float32(n) * targetLoad) 338 mapSz = uint32(float32(n) * maxLoadFactor) 339 return 340 } 341 342 type probeStats struct { 343 groups uint32 344 loadFactor float32 345 presentCnt uint32 346 presentMin uint32 347 presentMax uint32 348 presentAvg float32 349 absentCnt uint32 350 absentMin uint32 351 absentMax uint32 352 absentAvg float32 353 } 354 355 func fmtProbeStats(s probeStats) string { 356 g := fmt.Sprintf("groups=%d load=%f\n", s.groups, s.loadFactor) 357 p := fmt.Sprintf("present(n=%d): min=%d max=%d avg=%f\n", 358 s.presentCnt, s.presentMin, s.presentMax, s.presentAvg) 359 a := fmt.Sprintf("absent(n=%d): min=%d max=%d avg=%f\n", 360 s.absentCnt, s.absentMin, s.absentMax, s.absentAvg) 361 return g + p + a 362 } 363 364 func getProbeLength[K comparable, V any](t *testing.T, m *Map[K, V], key K) (length uint32, ok bool) { 365 var end uint32 366 hi, lo := splitHash(m.hash.Hash(key)) 367 start := probeStart(hi, len(m.groups)) 368 end, _, ok = m.find(key, hi, lo) 369 if end < start { // wrapped 370 end += uint32(len(m.groups)) 371 } 372 length = (end - start) + 1 373 require.True(t, length > 0) 374 return 375 } 376 377 func getProbeStats[K comparable, V any](t *testing.T, m *Map[K, V], keys []K) (stats probeStats) { 378 stats.groups = uint32(len(m.groups)) 379 stats.loadFactor = m.loadFactor() 380 var presentSum, absentSum float32 381 stats.presentMin = math.MaxInt32 382 stats.absentMin = math.MaxInt32 383 for _, key := range keys { 384 l, ok := getProbeLength(t, m, key) 385 if ok { 386 stats.presentCnt++ 387 presentSum += float32(l) 388 if stats.presentMin > l { 389 stats.presentMin = l 390 } 391 if stats.presentMax < l { 392 stats.presentMax = l 393 } 394 } else { 395 stats.absentCnt++ 396 absentSum += float32(l) 397 if stats.absentMin > l { 398 stats.absentMin = l 399 } 400 if stats.absentMax < l { 401 stats.absentMax = l 402 } 403 } 404 } 405 if stats.presentCnt == 0 { 406 stats.presentMin = 0 407 } else { 408 stats.presentAvg = presentSum / float32(stats.presentCnt) 409 } 410 if stats.absentCnt == 0 { 411 stats.absentMin = 0 412 } else { 413 stats.absentAvg = absentSum / float32(stats.absentCnt) 414 } 415 return 416 } 417 418 func TestNumGroups(t *testing.T) { 419 assert.Equal(t, expected(0), numGroups(0)) 420 assert.Equal(t, expected(1), numGroups(1)) 421 // max load factor 0.875 422 assert.Equal(t, expected(14), numGroups(14)) 423 assert.Equal(t, expected(15), numGroups(15)) 424 assert.Equal(t, expected(28), numGroups(28)) 425 assert.Equal(t, expected(29), numGroups(29)) 426 assert.Equal(t, expected(56), numGroups(56)) 427 assert.Equal(t, expected(57), numGroups(57)) 428 } 429 430 func expected(x int) (groups uint32) { 431 groups = uint32(math.Ceil(float64(x) / float64(maxAvgGroupLoad))) 432 if groups == 0 { 433 groups = 1 434 } 435 return 436 }