github.com/zuoyebang/bitalostable@v1.0.1-0.20240229032404-e3b99a834294/snapshot_test.go (about)

     1  // Copyright 2012 The LevelDB-Go and Pebble and Bitalostored Authors. All rights reserved. Use
     2  // of this source code is governed by a BSD-style license that can be found in
     3  // the LICENSE file.
     4  
     5  package bitalostable
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"reflect"
    11  	"runtime"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/cockroachdb/errors"
    18  	"github.com/stretchr/testify/require"
    19  	"github.com/zuoyebang/bitalostable/internal/datadriven"
    20  	"github.com/zuoyebang/bitalostable/vfs"
    21  )
    22  
    23  func TestSnapshotListToSlice(t *testing.T) {
    24  	testCases := []struct {
    25  		vals []uint64
    26  	}{
    27  		{nil},
    28  		{[]uint64{1}},
    29  		{[]uint64{1, 2, 3}},
    30  		{[]uint64{3, 2, 1}},
    31  	}
    32  	for _, c := range testCases {
    33  		t.Run("", func(t *testing.T) {
    34  			var l snapshotList
    35  			l.init()
    36  			for _, v := range c.vals {
    37  				l.pushBack(&Snapshot{seqNum: v})
    38  			}
    39  			slice := l.toSlice()
    40  			if !reflect.DeepEqual(c.vals, slice) {
    41  				t.Fatalf("expected %d, but got %d", c.vals, slice)
    42  			}
    43  		})
    44  	}
    45  }
    46  
    47  func TestSnapshot(t *testing.T) {
    48  	var d *DB
    49  	var snapshots map[string]*Snapshot
    50  
    51  	close := func() {
    52  		for _, s := range snapshots {
    53  			require.NoError(t, s.Close())
    54  		}
    55  		snapshots = nil
    56  		if d != nil {
    57  			require.NoError(t, d.Close())
    58  			d = nil
    59  		}
    60  	}
    61  	defer close()
    62  
    63  	datadriven.RunTest(t, "testdata/snapshot", func(td *datadriven.TestData) string {
    64  		switch td.Cmd {
    65  		case "define":
    66  			close()
    67  
    68  			var err error
    69  			d, err = Open("", &Options{
    70  				FS: vfs.NewMem(),
    71  			})
    72  			if err != nil {
    73  				return err.Error()
    74  			}
    75  			snapshots = make(map[string]*Snapshot)
    76  
    77  			for _, line := range strings.Split(td.Input, "\n") {
    78  				parts := strings.Fields(line)
    79  				if len(parts) == 0 {
    80  					continue
    81  				}
    82  				var err error
    83  				switch parts[0] {
    84  				case "set":
    85  					if len(parts) != 3 {
    86  						return fmt.Sprintf("%s expects 2 arguments", parts[0])
    87  					}
    88  					err = d.Set([]byte(parts[1]), []byte(parts[2]), nil)
    89  				case "del":
    90  					if len(parts) != 2 {
    91  						return fmt.Sprintf("%s expects 1 argument", parts[0])
    92  					}
    93  					err = d.Delete([]byte(parts[1]), nil)
    94  				case "merge":
    95  					if len(parts) != 3 {
    96  						return fmt.Sprintf("%s expects 2 arguments", parts[0])
    97  					}
    98  					err = d.Merge([]byte(parts[1]), []byte(parts[2]), nil)
    99  				case "snapshot":
   100  					if len(parts) != 2 {
   101  						return fmt.Sprintf("%s expects 1 argument", parts[0])
   102  					}
   103  					snapshots[parts[1]] = d.NewSnapshot()
   104  				case "compact":
   105  					if len(parts) != 2 {
   106  						return fmt.Sprintf("%s expects 1 argument", parts[0])
   107  					}
   108  					keys := strings.Split(parts[1], "-")
   109  					if len(keys) != 2 {
   110  						return fmt.Sprintf("malformed key range: %s", parts[1])
   111  					}
   112  					err = d.Compact([]byte(keys[0]), []byte(keys[1]), false)
   113  				default:
   114  					return fmt.Sprintf("unknown op: %s", parts[0])
   115  				}
   116  				if err != nil {
   117  					return err.Error()
   118  				}
   119  			}
   120  			return ""
   121  
   122  		case "iter":
   123  			var iter *Iterator
   124  			if len(td.CmdArgs) == 1 {
   125  				if td.CmdArgs[0].Key != "snapshot" {
   126  					return fmt.Sprintf("unknown argument: %s", td.CmdArgs[0])
   127  				}
   128  				if len(td.CmdArgs[0].Vals) != 1 {
   129  					return fmt.Sprintf("%s expects 1 value: %s", td.CmdArgs[0].Key, td.CmdArgs[0])
   130  				}
   131  				name := td.CmdArgs[0].Vals[0]
   132  				snapshot := snapshots[name]
   133  				if snapshot == nil {
   134  					return fmt.Sprintf("unable to find snapshot \"%s\"", name)
   135  				}
   136  				iter = snapshot.NewIter(nil)
   137  			} else {
   138  				iter = d.NewIter(nil)
   139  			}
   140  			defer iter.Close()
   141  
   142  			var b bytes.Buffer
   143  			for _, line := range strings.Split(td.Input, "\n") {
   144  				parts := strings.Fields(line)
   145  				if len(parts) == 0 {
   146  					continue
   147  				}
   148  				switch parts[0] {
   149  				case "first":
   150  					iter.First()
   151  				case "last":
   152  					iter.Last()
   153  				case "seek-ge":
   154  					if len(parts) != 2 {
   155  						return "seek-ge <key>\n"
   156  					}
   157  					iter.SeekGE([]byte(strings.TrimSpace(parts[1])))
   158  				case "seek-lt":
   159  					if len(parts) != 2 {
   160  						return "seek-lt <key>\n"
   161  					}
   162  					iter.SeekLT([]byte(strings.TrimSpace(parts[1])))
   163  				case "next":
   164  					iter.Next()
   165  				case "prev":
   166  					iter.Prev()
   167  				default:
   168  					return fmt.Sprintf("unknown op: %s", parts[0])
   169  				}
   170  				if iter.Valid() {
   171  					fmt.Fprintf(&b, "%s:%s\n", iter.Key(), iter.Value())
   172  				} else if err := iter.Error(); err != nil {
   173  					fmt.Fprintf(&b, "err=%v\n", err)
   174  				} else {
   175  					fmt.Fprintf(&b, ".\n")
   176  				}
   177  			}
   178  			return b.String()
   179  
   180  		default:
   181  			return fmt.Sprintf("unknown command: %s", td.Cmd)
   182  		}
   183  	})
   184  }
   185  
   186  func TestSnapshotClosed(t *testing.T) {
   187  	d, err := Open("", &Options{
   188  		FS: vfs.NewMem(),
   189  	})
   190  	require.NoError(t, err)
   191  
   192  	catch := func(f func()) (err error) {
   193  		defer func() {
   194  			if r := recover(); r != nil {
   195  				err = r.(error)
   196  			}
   197  		}()
   198  		f()
   199  		return nil
   200  	}
   201  
   202  	snap := d.NewSnapshot()
   203  	require.NoError(t, snap.Close())
   204  	require.True(t, errors.Is(catch(func() { _ = snap.Close() }), ErrClosed))
   205  	require.True(t, errors.Is(catch(func() { _, _, _ = snap.Get(nil) }), ErrClosed))
   206  	require.True(t, errors.Is(catch(func() { snap.NewIter(nil) }), ErrClosed))
   207  
   208  	require.NoError(t, d.Close())
   209  }
   210  
   211  func TestSnapshotRangeDeletionStress(t *testing.T) {
   212  	const runs = 200
   213  	const middleKey = runs * runs
   214  
   215  	d, err := Open("", &Options{
   216  		FS: vfs.NewMem(),
   217  	})
   218  	require.NoError(t, err)
   219  
   220  	mkkey := func(k int) []byte {
   221  		return []byte(fmt.Sprintf("%08d", k))
   222  	}
   223  	v := []byte("hello world")
   224  
   225  	snapshots := make([]*Snapshot, 0, runs)
   226  	for r := 0; r < runs; r++ {
   227  		// We use a keyspace that is 2*runs*runs wide. In other words there are
   228  		// 2*runs sections of the keyspace, each with runs elements. On every
   229  		// run, we write to the r-th element of each section of the keyspace.
   230  		for i := 0; i < 2*runs; i++ {
   231  			err := d.Set(mkkey(runs*i+r), v, nil)
   232  			require.NoError(t, err)
   233  		}
   234  
   235  		// Now we delete some of the keyspace through a DeleteRange. We delete from
   236  		// the middle of the keyspace outwards. The keyspace is made of 2*runs
   237  		// sections, and we delete an additional two of these sections per run.
   238  		err := d.DeleteRange(mkkey(middleKey-runs*r), mkkey(middleKey+runs*r), nil)
   239  		require.NoError(t, err)
   240  
   241  		snapshots = append(snapshots, d.NewSnapshot())
   242  	}
   243  
   244  	// Check that all the snapshots contain the expected number of keys.
   245  	// Iterating over so many keys is slow, so do it in parallel.
   246  	var wg sync.WaitGroup
   247  	sem := make(chan struct{}, runtime.GOMAXPROCS(0))
   248  	for r := range snapshots {
   249  		wg.Add(1)
   250  		sem <- struct{}{}
   251  		go func(r int) {
   252  			defer func() {
   253  				<-sem
   254  				wg.Done()
   255  			}()
   256  
   257  			// Count the keys at this snapshot.
   258  			iter := snapshots[r].NewIter(nil)
   259  			var keysFound int
   260  			for iter.First(); iter.Valid(); iter.Next() {
   261  				keysFound++
   262  			}
   263  			err := firstError(iter.Error(), iter.Close())
   264  			if err != nil {
   265  				t.Error(err)
   266  				return
   267  			}
   268  
   269  			// At the time that this snapshot was taken, (r+1)*2*runs unique keys
   270  			// were Set (one in each of the 2*runs sections per run).  But this
   271  			// run also deleted the 2*r middlemost sections.  When this snapshot
   272  			// was taken, a Set to each of those sections had been made (r+1)
   273  			// times, so 2*r*(r+1) previously-set keys are now deleted.
   274  
   275  			keysExpected := (r+1)*2*runs - 2*r*(r+1)
   276  			if keysFound != keysExpected {
   277  				t.Errorf("%d: found %d keys, want %d", r, keysFound, keysExpected)
   278  			}
   279  			if err := snapshots[r].Close(); err != nil {
   280  				t.Error(err)
   281  			}
   282  		}(r)
   283  	}
   284  	wg.Wait()
   285  	require.NoError(t, d.Close())
   286  }
   287  
   288  // TestNewSnapshotRace tests atomicity of NewSnapshot.
   289  //
   290  // It tests for a regression of a previous race condition in which NewSnapshot
   291  // would retrieve the visible sequence number for a new snapshot before
   292  // locking the database mutex to add the snapshot. A write and flush that
   293  // that occurred between the reading of the sequence number and appending the
   294  // snapshot could drop keys required by the snapshot.
   295  func TestNewSnapshotRace(t *testing.T) {
   296  	const runs = 10
   297  	d, err := Open("", &Options{FS: vfs.NewMem()})
   298  	require.NoError(t, err)
   299  
   300  	v := []byte(`foo`)
   301  	ch := make(chan string)
   302  	var wg sync.WaitGroup
   303  	wg.Add(1)
   304  
   305  	go func() {
   306  		defer wg.Done()
   307  		for k := range ch {
   308  			if err := d.Set([]byte(k), v, nil); err != nil {
   309  				t.Error(err)
   310  				return
   311  			}
   312  			if err := d.Flush(); err != nil {
   313  				t.Error(err)
   314  				return
   315  			}
   316  		}
   317  	}()
   318  	for i := 0; i < runs; i++ {
   319  		// This main test goroutine sets `k` before creating a new snapshot.
   320  		// The key `k` should always be present within the snapshot.
   321  		k := fmt.Sprintf("key%06d", i)
   322  		require.NoError(t, d.Set([]byte(k), v, nil))
   323  
   324  		// Lock d.mu in another goroutine so that our call to NewSnapshot
   325  		// will need to contend for d.mu.
   326  		wg.Add(1)
   327  		locked := make(chan struct{})
   328  		go func() {
   329  			defer wg.Done()
   330  			d.mu.Lock()
   331  			close(locked)
   332  			time.Sleep(20 * time.Millisecond)
   333  			d.mu.Unlock()
   334  		}()
   335  		<-locked
   336  
   337  		// Tell the other goroutine to overwrite `k` with a later sequence
   338  		// number. It's indeterminate which key we'll read, but we should
   339  		// always read one of them.
   340  		ch <- k
   341  		s := d.NewSnapshot()
   342  		_, c, err := s.Get([]byte(k))
   343  		require.NoError(t, err)
   344  		require.NoError(t, c.Close())
   345  		require.NoError(t, s.Close())
   346  	}
   347  	close(ch)
   348  	wg.Wait()
   349  	require.NoError(t, d.Close())
   350  }