github.com/dgraph-io/ristretto@v0.1.2-0.20240116140435-c67e07994f91/z/btree_test.go (about)

     1  /*
     2   * Copyright 2020 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package z
    18  
    19  import (
    20  	"fmt"
    21  	"math"
    22  	"math/rand"
    23  	"os"
    24  	"path/filepath"
    25  	"sort"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/dgraph-io/ristretto/z/simd"
    30  	"github.com/dustin/go-humanize"
    31  	"github.com/stretchr/testify/require"
    32  )
    33  
// tmp is a package-level sink: BenchmarkSearch and BenchmarkCustomSearch store
// every search result in it (presumably so the compiler cannot discard the
// calls being measured).
var tmp int
    35  
    36  func setPageSize(sz int) {
    37  	pageSize = sz
    38  	maxKeys = (pageSize / 16) - 1
    39  }
    40  
    41  func TestTree(t *testing.T) {
    42  	bt := NewTree("TestTree")
    43  	defer func() { require.NoError(t, bt.Close()) }()
    44  
    45  	N := uint64(256 * 256)
    46  	for i := uint64(1); i < N; i++ {
    47  		bt.Set(i, i)
    48  	}
    49  	for i := uint64(1); i < N; i++ {
    50  		require.Equal(t, i, bt.Get(i))
    51  	}
    52  
    53  	bt.DeleteBelow(100)
    54  	for i := uint64(1); i < 100; i++ {
    55  		require.Equal(t, uint64(0), bt.Get(i))
    56  	}
    57  	for i := uint64(100); i < N; i++ {
    58  		require.Equal(t, i, bt.Get(i))
    59  	}
    60  }
    61  
    62  func TestTreePersistent(t *testing.T) {
    63  	dir, err := os.MkdirTemp("", "")
    64  	require.NoError(t, err)
    65  	defer os.RemoveAll(dir)
    66  	path := filepath.Join(dir, "tree.buf")
    67  
    68  	// Create a tree and validate the data.
    69  	bt1, err := NewTreePersistent(path)
    70  	require.NoError(t, err)
    71  	N := uint64(64 << 10)
    72  	for i := uint64(1); i < N; i++ {
    73  		bt1.Set(i, i*2)
    74  	}
    75  	for i := uint64(1); i < N; i++ {
    76  		require.Equal(t, i*2, bt1.Get(i))
    77  	}
    78  	bt1Stats := bt1.Stats()
    79  	require.NoError(t, bt1.Close())
    80  
    81  	// Reopen tree and validate the data.
    82  	bt2, err := NewTreePersistent(path)
    83  	require.NoError(t, err)
    84  	require.Equal(t, bt2.freePage, bt1.freePage)
    85  	require.Equal(t, bt2.nextPage, bt1.nextPage)
    86  	bt2Stats := bt2.Stats()
    87  	// When reopening a tree, the allocated size becomes the file size.
    88  	// We don't need to compare this, it doesn't change anything in the tree.
    89  	bt2Stats.Allocated = bt1Stats.Allocated
    90  	require.Equal(t, bt1Stats, bt2Stats)
    91  	for i := uint64(1); i < N; i++ {
    92  		require.Equal(t, i*2, bt2.Get(i))
    93  	}
    94  	// Delete all the data. This will change the value of bt.freePage.
    95  	bt2.DeleteBelow(math.MaxUint64)
    96  	bt2Stats = bt2.Stats()
    97  	require.NoError(t, bt2.Close())
    98  
    99  	// Reopen tree and validate the data.
   100  	bt3, err := NewTreePersistent(path)
   101  	require.NoError(t, err)
   102  	require.Equal(t, bt2.freePage, bt3.freePage)
   103  	require.Equal(t, bt2.nextPage, bt3.nextPage)
   104  	bt3Stats := bt3.Stats()
   105  	bt3Stats.Allocated = bt2Stats.Allocated
   106  	require.Equal(t, bt2Stats, bt3Stats)
   107  	require.NoError(t, bt3.Close())
   108  }
   109  
   110  func TestTreeBasic(t *testing.T) {
   111  	setAndGet := func() {
   112  		bt := NewTree("TestTreeBasic")
   113  		defer func() { require.NoError(t, bt.Close()) }()
   114  
   115  		N := uint64(1 << 20)
   116  		mp := make(map[uint64]uint64)
   117  		for i := uint64(1); i < N; i++ {
   118  			key := uint64(rand.Int63n(1<<60) + 1)
   119  			bt.Set(key, key)
   120  			mp[key] = key
   121  		}
   122  		for k, v := range mp {
   123  			require.Equal(t, v, bt.Get(k))
   124  		}
   125  
   126  		stats := bt.Stats()
   127  		t.Logf("final stats: %+v\n", stats)
   128  	}
   129  	setAndGet()
   130  	defer setPageSize(os.Getpagesize())
   131  	setPageSize(16 << 5)
   132  	setAndGet()
   133  }
   134  
   135  func TestTreeReset(t *testing.T) {
   136  	bt := NewTree("TestTreeReset")
   137  	defer func() { require.NoError(t, bt.Close()) }()
   138  
   139  	N := 1 << 10
   140  	val := rand.Uint64()
   141  	for i := 0; i < N; i++ {
   142  		bt.Set(rand.Uint64(), val)
   143  	}
   144  
   145  	// Truncate it to small size that is less than pageSize.
   146  	bt.Reset()
   147  
   148  	stats := bt.Stats()
   149  	// Verify the tree stats.
   150  	require.Equal(t, 2, stats.NumPages)
   151  	require.Equal(t, 1, stats.NumLeafKeys)
   152  	require.Equal(t, 2*pageSize, stats.Bytes)
   153  	expectedOcc := float64(1) * 100 / float64(2*maxKeys)
   154  	require.InDelta(t, expectedOcc, stats.Occupancy, 0.01)
   155  	require.Zero(t, stats.NumPagesFree)
   156  	// Check if we can reinsert the data.
   157  	mp := make(map[uint64]uint64)
   158  	for i := 0; i < N; i++ {
   159  		k := rand.Uint64()
   160  		mp[k] = val
   161  		bt.Set(k, val)
   162  	}
   163  	for k, v := range mp {
   164  		require.Equal(t, v, bt.Get(k))
   165  	}
   166  }
   167  
   168  func TestTreeCycle(t *testing.T) {
   169  	bt := NewTree("TestTreeCycle")
   170  	defer func() { require.NoError(t, bt.Close()) }()
   171  
   172  	val := uint64(0)
   173  	for i := 0; i < 16; i++ {
   174  		for j := 0; j < 1e6+i*1e4; j++ {
   175  			val += 1
   176  			bt.Set(rand.Uint64(), val)
   177  		}
   178  		before := bt.Stats()
   179  		bt.DeleteBelow(val - 1e4)
   180  		after := bt.Stats()
   181  		t.Logf("Cycle %d Done. Before: %+v -> After: %+v\n", i, before, after)
   182  	}
   183  
   184  	bt.DeleteBelow(val)
   185  	stats := bt.Stats()
   186  	t.Logf("stats: %+v\n", stats)
   187  	require.LessOrEqual(t, stats.Occupancy, 1.0)
   188  	require.GreaterOrEqual(t, stats.NumPagesFree, int(float64(stats.NumPages)*0.95))
   189  }
   190  
   191  func TestTreeIterateKV(t *testing.T) {
   192  	bt := NewTree("TestTreeIterateKV")
   193  	defer func() { require.NoError(t, bt.Close()) }()
   194  
   195  	// Set entries: (i, i*10)
   196  	const n = uint64(1 << 20)
   197  	for i := uint64(1); i <= n; i++ {
   198  		bt.Set(i, i*10)
   199  	}
   200  
   201  	// Validate entries: (i, i*10)
   202  	// Set entries: (i, i*20)
   203  	count := uint64(0)
   204  	bt.IterateKV(func(k, v uint64) uint64 {
   205  		require.Equal(t, k*10, v)
   206  		count++
   207  		return k * 20
   208  	})
   209  	require.Equal(t, n, count)
   210  
   211  	// Validate entries: (i, i*20)
   212  	count = uint64(0)
   213  	bt.IterateKV(func(k, v uint64) uint64 {
   214  		require.Equal(t, k*20, v)
   215  		count++
   216  		return 0
   217  	})
   218  	require.Equal(t, n, count)
   219  }
   220  
   221  func TestOccupancyRatio(t *testing.T) {
   222  	// atmax 4 keys per node
   223  	setPageSize(16 * 5)
   224  	defer setPageSize(os.Getpagesize())
   225  	require.Equal(t, 4, maxKeys)
   226  
   227  	bt := NewTree("TestOccupancyRatio")
   228  	defer func() { require.NoError(t, bt.Close()) }()
   229  
   230  	expectedRatio := float64(1) * 100 / float64(2*maxKeys) // 2 because we'll have 2 pages.
   231  	stats := bt.Stats()
   232  	t.Logf("Expected ratio: %.2f. MaxKeys: %d. Stats: %+v\n", expectedRatio, maxKeys, stats)
   233  	require.InDelta(t, expectedRatio, stats.Occupancy, 0.01)
   234  	for i := uint64(1); i <= 3; i++ {
   235  		bt.Set(i, i)
   236  	}
   237  	// Tree structure will be:
   238  	//    [2,Max,_,_]
   239  	//  [1,2,_,_]  [3,Max,_,_]
   240  	expectedRatio = float64(4) * 100 / float64(3*maxKeys)
   241  	stats = bt.Stats()
   242  	t.Logf("Expected ratio: %.2f. MaxKeys: %d. Stats: %+v\n", expectedRatio, maxKeys, stats)
   243  	require.InDelta(t, expectedRatio, stats.Occupancy, 0.01)
   244  	bt.DeleteBelow(2)
   245  	// Tree structure will be:
   246  	//    [2,Max,_]
   247  	//  [2,_,_,_]  [3,Max,_,_]
   248  	expectedRatio = float64(3) * 100 / float64(3*maxKeys)
   249  	stats = bt.Stats()
   250  	t.Logf("Expected ratio: %.2f. MaxKeys: %d. Stats: %+v\n", expectedRatio, maxKeys, stats)
   251  	require.InDelta(t, expectedRatio, stats.Occupancy, 0.01)
   252  }
   253  
   254  func TestNode(t *testing.T) {
   255  	n := getNode(make([]byte, pageSize))
   256  	for i := uint64(1); i < 16; i *= 2 {
   257  		n.set(i, i)
   258  	}
   259  	n.print(0)
   260  	require.True(t, 0 == n.get(5))
   261  	n.set(5, 5)
   262  	n.print(0)
   263  
   264  	n.setBit(0)
   265  	require.False(t, n.isLeaf())
   266  	n.setBit(bitLeaf)
   267  	require.True(t, n.isLeaf())
   268  }
   269  
   270  func TestNodeBasic(t *testing.T) {
   271  	n := getNode(make([]byte, pageSize))
   272  	N := uint64(256)
   273  	mp := make(map[uint64]uint64)
   274  	for i := uint64(1); i < N; i++ {
   275  		key := uint64(rand.Int63n(1<<60) + 1)
   276  		n.set(key, key)
   277  		mp[key] = key
   278  	}
   279  	for k, v := range mp {
   280  		require.Equal(t, v, n.get(k))
   281  	}
   282  }
   283  
   284  func TestNode_MoveRight(t *testing.T) {
   285  	n := getNode(make([]byte, pageSize))
   286  	N := uint64(10)
   287  	for i := uint64(1); i < N; i++ {
   288  		n.set(i, i)
   289  	}
   290  	n.moveRight(5)
   291  	n.iterate(func(n node, i int) {
   292  		if i < 5 {
   293  			require.Equal(t, uint64(i+1), n.key(i))
   294  			require.Equal(t, uint64(i+1), n.val(i))
   295  		} else if i > 5 {
   296  			require.Equal(t, uint64(i), n.key(i))
   297  			require.Equal(t, uint64(i), n.val(i))
   298  		}
   299  	})
   300  }
   301  
   302  func TestNodeCompact(t *testing.T) {
   303  	n := getNode(make([]byte, pageSize))
   304  	n.setBit(bitLeaf)
   305  	N := uint64(128)
   306  	mp := make(map[uint64]uint64)
   307  	for i := uint64(1); i < N; i++ {
   308  		key := i
   309  		val := uint64(10)
   310  		if i%2 == 0 {
   311  			val = 20
   312  			mp[key] = 20
   313  		}
   314  		n.set(key, val)
   315  	}
   316  
   317  	require.Equal(t, int(N/2), n.compact(11))
   318  	for k, v := range mp {
   319  		require.Equal(t, v, n.get(k))
   320  	}
   321  	require.Equal(t, uint64(127), n.maxKey())
   322  }
   323  
   324  func BenchmarkPurge(b *testing.B) {
   325  	N := 16 << 20
   326  	b.Run("go-mem", func(b *testing.B) {
   327  		m := make(map[uint64]uint64)
   328  		for i := 0; i < N; i++ {
   329  			m[rand.Uint64()] = uint64(i)
   330  		}
   331  	})
   332  
   333  	b.Run("btree", func(b *testing.B) {
   334  		start := time.Now()
   335  		bt := NewTree("BenchmarkPurge")
   336  		defer func() { require.NoError(b, bt.Close()) }()
   337  		for i := 0; i < N; i++ {
   338  			bt.Set(rand.Uint64(), uint64(i))
   339  		}
   340  		b.Logf("Populate took: %s. stats: %+v\n", time.Since(start), bt.Stats())
   341  
   342  		start = time.Now()
   343  		before := bt.Stats()
   344  		bt.DeleteBelow(uint64(N - 1<<20))
   345  		after := bt.Stats()
   346  		b.Logf("Purge took: %s. Before: %+v After: %+v\n", time.Since(start), before, after)
   347  	})
   348  }
   349  
   350  func BenchmarkWrite(b *testing.B) {
   351  	b.Run("map", func(b *testing.B) {
   352  		mp := make(map[uint64]uint64)
   353  		for n := 0; n < b.N; n++ {
   354  			k := rand.Uint64()
   355  			mp[k] = k
   356  		}
   357  	})
   358  	b.Run("btree", func(b *testing.B) {
   359  		bt := NewTree("BenchmarkWrite")
   360  		defer func() { require.NoError(b, bt.Close()) }()
   361  		b.ResetTimer()
   362  		for n := 0; n < b.N; n++ {
   363  			k := rand.Uint64()
   364  			bt.Set(k, k)
   365  		}
   366  	})
   367  }
   368  
   369  // goos: linux
   370  // goarch: amd64
   371  // pkg: github.com/dgraph-io/ristretto/z
   372  // BenchmarkRead/map-4         	10845322	       109 ns/op
   373  // BenchmarkRead/btree-4       	 2744283	       430 ns/op
   374  // Cumulative for 10 runs.
   375  // name          time/op
   376  // Read/map-4    105ns ± 1%
   377  // Read/btree-4  422ns ± 1%
   378  func BenchmarkRead(b *testing.B) {
   379  	N := 10 << 20
   380  	mp := make(map[uint64]uint64)
   381  	for i := 0; i < N; i++ {
   382  		k := uint64(rand.Intn(2*N)) + 1
   383  		mp[k] = k
   384  	}
   385  	b.Run("map", func(b *testing.B) {
   386  		for i := 0; i < b.N; i++ {
   387  			k := uint64(rand.Intn(2 * N))
   388  			v, ok := mp[k]
   389  			_, _ = v, ok
   390  		}
   391  	})
   392  
   393  	bt := NewTree("BenchmarkRead")
   394  	defer func() { require.NoError(b, bt.Close()) }()
   395  	for i := 0; i < N; i++ {
   396  		k := uint64(rand.Intn(2*N)) + 1
   397  		bt.Set(k, k)
   398  	}
   399  	stats := bt.Stats()
   400  	fmt.Printf("Num pages: %d Size: %s\n", stats.NumPages,
   401  		humanize.IBytes(uint64(stats.Bytes)))
   402  	fmt.Println("Writes done.")
   403  
   404  	b.Run("btree", func(b *testing.B) {
   405  		for i := 0; i < b.N; i++ {
   406  			k := uint64(rand.Intn(2*N)) + 1
   407  			v := bt.Get(k)
   408  			_ = v
   409  		}
   410  	})
   411  }
   412  
   413  func BenchmarkSearch(b *testing.B) {
   414  	linear := func(n node, k uint64, N int) int {
   415  		for i := 0; i < N; i++ {
   416  			if ki := n.key(i); ki >= k {
   417  				return i
   418  			}
   419  		}
   420  		return N
   421  	}
   422  	binary := func(n node, k uint64, N int) int {
   423  		return sort.Search(N, func(i int) bool {
   424  			return n.key(i) >= k
   425  		})
   426  	}
   427  	unroll4 := func(n node, k uint64, N int) int {
   428  		if len(n[:2*N]) < 8 {
   429  			for i := 0; i < N; i++ {
   430  				if ki := n.key(i); ki >= k {
   431  					return i
   432  				}
   433  			}
   434  			return N
   435  		}
   436  		return int(simd.Search(n[:2*N], k))
   437  	}
   438  
   439  	jumpBy := []int{8, 16, 32, 64, 128, 196, 255}
   440  	for _, sz := range jumpBy {
   441  		f, err := os.CreateTemp(".", "tree")
   442  		require.NoError(b, err)
   443  
   444  		mf, err := OpenMmapFileUsing(f, pageSize, true)
   445  		if err != NewFile {
   446  			require.NoError(b, err)
   447  		}
   448  
   449  		n := getNode(mf.Data)
   450  		for i := 1; i <= sz; i++ {
   451  			n.set(uint64(i), uint64(i))
   452  		}
   453  
   454  		b.Run(fmt.Sprintf("linear-%d", sz), func(b *testing.B) {
   455  			for i := 0; i < b.N; i++ {
   456  				tmp = linear(n, math.MaxUint64, sz)
   457  			}
   458  		})
   459  		b.Run(fmt.Sprintf("binary-%d", sz), func(b *testing.B) {
   460  			for i := 0; i < b.N; i++ {
   461  				tmp = binary(n, uint64(sz), sz)
   462  			}
   463  		})
   464  		b.Run(fmt.Sprintf("unrolled-asm-%d", sz), func(b *testing.B) {
   465  			for i := 0; i < b.N; i++ {
   466  				tmp = unroll4(n, math.MaxUint64, sz)
   467  			}
   468  		})
   469  		mf.Close(0)
   470  		os.Remove(f.Name())
   471  	}
   472  }
   473  
   474  // This benchmark when run on dgus-delta, performed marginally better with threshold=32.
   475  // CustomSearch/sz-64_th-1-4     49.9ns ± 1% (fully binary)
   476  // CustomSearch/sz-64_th-16-4    63.3ns ± 0%
   477  // CustomSearch/sz-64_th-32-4    58.7ns ± 7%
   478  // CustomSearch/sz-64_th-64-4    63.9ns ± 7% (fully linear)
   479  
   480  // CustomSearch/sz-128_th-32-4   70.2ns ± 1%
   481  
   482  // CustomSearch/sz-255_th-1-4    77.3ns ± 0% (fully binary)
   483  // CustomSearch/sz-255_th-16-4   68.2ns ± 1%
   484  // CustomSearch/sz-255_th-32-4   67.0ns ± 7%
   485  // CustomSearch/sz-255_th-64-4   85.5ns ±19%
   486  // CustomSearch/sz-255_th-256-4   129ns ± 6% (fully linear)
   487  
   488  func BenchmarkCustomSearch(b *testing.B) {
   489  	mixed := func(n node, k uint64, N int, threshold int) int {
   490  		lo, hi := 0, N
   491  		// Reduce the search space using binary seach and then do linear search.
   492  		for hi-lo > threshold {
   493  			mid := (hi + lo) / 2
   494  			km := n.key(mid)
   495  			if k == km {
   496  				return mid
   497  			}
   498  			if k > km {
   499  				// key is greater than the key at mid, so move right.
   500  				lo = mid + 1
   501  			} else {
   502  				// else move left.
   503  				hi = mid
   504  			}
   505  		}
   506  		for i := lo; i <= hi; i++ {
   507  			if ki := n.key(i); ki >= k {
   508  				return i
   509  			}
   510  		}
   511  		return N
   512  	}
   513  
   514  	for _, sz := range []int{64, 128, 255} {
   515  		n := getNode(make([]byte, pageSize))
   516  		for i := 1; i <= sz; i++ {
   517  			n.set(uint64(i), uint64(i))
   518  		}
   519  
   520  		mk := sz + 1
   521  		for th := 1; th <= sz+1; th *= 2 {
   522  			b.Run(fmt.Sprintf("sz-%d th-%d", sz, th), func(b *testing.B) {
   523  				for i := 0; i < b.N; i++ {
   524  					k := uint64(rand.Intn(mk))
   525  					tmp = mixed(n, k, sz, th)
   526  				}
   527  			})
   528  		}
   529  	}
   530  }