github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/kv/kvserver/protectedts/ptcache/cache_test.go

// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package ptcache_test

import (
	"bytes"
	"context"
	"sort"
	"testing"
	"time"

	"github.com/cockroachdb/cockroach/pkg/base"
	"github.com/cockroachdb/cockroach/pkg/keys"
	"github.com/cockroachdb/cockroach/pkg/kv/kvserver"
	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverbase"
	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts"
	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptcache"
	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptpb"
	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptstorage"
	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/sql/sqlutil"
	"github.com/cockroachdb/cockroach/pkg/testutils"
	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
	"github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
	"github.com/cockroachdb/cockroach/pkg/util/hlc"
	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
	"github.com/cockroachdb/cockroach/pkg/util/stop"
	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
	"github.com/cockroachdb/cockroach/pkg/util/uuid"
	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestCacheBasic exercises the basic behavior of the Cache.
func TestCacheBasic(t *testing.T) {
	defer leaktest.AfterTest(t)()
	ctx := context.Background()
	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
	defer tc.Stopper().Stop(ctx)
	s := tc.Server(0)
	p := ptstorage.WithDatabase(ptstorage.New(s.ClusterSettings(),
		s.InternalExecutor().(sqlutil.InternalExecutor)), s.DB())

	// Set the poll interval to be very short.
	protectedts.PollInterval.Override(&s.ClusterSettings().SV, 500*time.Microsecond)

	c := ptcache.New(ptcache.Config{
		Settings: s.ClusterSettings(),
		DB:       s.DB(),
		Storage:  p,
	})
	require.NoError(t, c.Start(ctx, tc.Stopper()))

	// Make sure that the protected timestamp state gets updated.
	ts := waitForAsOfAfter(t, c, hlc.Timestamp{})

	// Make sure that it gets updated again.
	waitForAsOfAfter(t, c, ts)

	// Then we'll add a record and make sure it gets seen.
	sp := tableSpan(42)
	r, createdAt := protect(t, tc.Server(0), p, sp)
	testutils.SucceedsSoon(t, func() error {
		var coveredBy []*ptpb.Record
		seenTS := c.Iterate(ctx, sp.Key, sp.EndKey,
			func(r *ptpb.Record) (wantMore bool) {
				coveredBy = append(coveredBy, r)
				return true
			})
		if len(coveredBy) == 0 {
			assert.True(t, seenTS.Less(createdAt), "%v %v", seenTS, createdAt)
			return errors.Errorf("expected %v to be covered", sp)
		}
		require.True(t, !seenTS.Less(createdAt), "%v %v", seenTS, createdAt)
		require.EqualValues(t, []*ptpb.Record{r}, coveredBy)
		return nil
	})

	// Then release the record and make sure that the release gets seen.
	require.Nil(t, p.Release(ctx, nil /* txn */, r.ID))
	testutils.SucceedsSoon(t, func() error {
		var coveredBy []*ptpb.Record
		_ = c.Iterate(ctx, sp.Key, sp.EndKey,
			func(r *ptpb.Record) (wantMore bool) {
				coveredBy = append(coveredBy, r)
				return true
			})
		if len(coveredBy) > 0 {
			return errors.Errorf("expected %v not to be covered", sp)
		}
		return nil
	})
}

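// TestRefresh exercises Cache.Refresh: a refresh at an already-seen timestamp
// is a no-op, a refresh past it scans only the meta table when nothing has
// changed and additionally scans the records table when a record has been
// added, errors from either scan propagate, and a canceled context causes the
// refresh to return early.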
func TestRefresh(t *testing.T) {
	defer leaktest.AfterTest(t)()
	ctx := context.Background()
	st := &scanTracker{}
	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{
		ServerArgs: base.TestServerArgs{
			Knobs: base.TestingKnobs{
				Store: &kvserver.StoreTestingKnobs{
					TestingRequestFilter: kvserverbase.ReplicaRequestFilter(st.requestFilter),
				},
			},
		},
	})
	defer tc.Stopper().Stop(ctx)
	s := tc.Server(0)
	p := ptstorage.WithDatabase(ptstorage.New(s.ClusterSettings(),
		s.InternalExecutor().(sqlutil.InternalExecutor)), s.DB())

	// Set the poll interval to be very long.
	protectedts.PollInterval.Override(&s.ClusterSettings().SV, 500*time.Hour)

	c := ptcache.New(ptcache.Config{
		Settings: s.ClusterSettings(),
		DB:       s.DB(),
		Storage:  p,
	})
	require.NoError(t, c.Start(ctx, tc.Stopper()))
	t.Run("already up-to-date", func(t *testing.T) {
		ts := waitForAsOfAfter(t, c, hlc.Timestamp{})
		st.resetCounters()
		require.NoError(t, c.Refresh(ctx, ts))
		st.verifyCounters(t, 0, 0) // already up to date
	})
	t.Run("needs refresh, no change", func(t *testing.T) {
		ts := waitForAsOfAfter(t, c, hlc.Timestamp{})
		require.NoError(t, c.Refresh(ctx, ts.Next()))
		st.verifyCounters(t, 1, 0) // just need to scan meta
	})
	t.Run("needs refresh, with change", func(t *testing.T) {
		_, createdAt := protect(t, s, p, metaTableSpan)
		st.resetCounters()
		require.NoError(t, c.Refresh(ctx, createdAt))
		st.verifyCounters(t, 2, 1) // need to scan meta and then scan everything
	})
	t.Run("cancelation returns early", func(t *testing.T) {
		withCancel, cancel := context.WithCancel(ctx)
		defer cancel()
		done := make(chan struct{})
		st.setFilter(func(ba roachpb.BatchRequest) *roachpb.Error {
			if scanReq, ok := ba.GetArg(roachpb.Scan); ok {
				scan := scanReq.(*roachpb.ScanRequest)
				if scan.Span().Overlaps(metaTableSpan) {
					<-done
				}
			}
			return nil
		})
		go func() { time.Sleep(time.Millisecond); cancel() }()
		require.EqualError(t, c.Refresh(withCancel, s.Clock().Now()),
			context.Canceled.Error())
		close(done)
	})
	t.Run("error propagates while fetching metadata", func(t *testing.T) {
		st.setFilter(func(ba roachpb.BatchRequest) *roachpb.Error {
			if scanReq, ok := ba.GetArg(roachpb.Scan); ok {
				scan := scanReq.(*roachpb.ScanRequest)
				if scan.Span().Overlaps(metaTableSpan) {
					return roachpb.NewError(errors.New("boom"))
				}
			}
			return nil
		})
		defer st.setFilter(nil)
		require.Regexp(t, "boom", c.Refresh(ctx, s.Clock().Now()).Error())
	})
	t.Run("error propagates while fetching records", func(t *testing.T) {
		protect(t, s, p, metaTableSpan)
		st.setFilter(func(ba roachpb.BatchRequest) *roachpb.Error {
			if scanReq, ok := ba.GetArg(roachpb.Scan); ok {
				scan := scanReq.(*roachpb.ScanRequest)
				if scan.Span().Overlaps(recordsTableSpan) {
					return roachpb.NewError(errors.New("boom"))
				}
			}
			return nil
		})
		defer st.setFilter(nil)
		require.Regexp(t, "boom", c.Refresh(ctx, s.Clock().Now()).Error())
	})
	t.Run("Iterate does not hold mutex", func(t *testing.T) {
		inIterate := make(chan chan struct{})
		rec, createdAt := protect(t, s, p, metaTableSpan)
		require.NoError(t, c.Refresh(ctx, createdAt))
		go c.Iterate(ctx, keys.MinKey, keys.MaxKey, func(r *ptpb.Record) (wantMore bool) {
			if r.ID != rec.ID {
				return true
			}
			// Make sure we see the record we created and use it to signal the main
			// goroutine.
			waitUntil := make(chan struct{})
			inIterate <- waitUntil
			<-waitUntil
			defer close(inIterate)
			return false
		})
		// Wait until we get to the record in iteration and pause, perform an
		// operation, and then refresh after it. This will demonstrate that the
		// iteration call does not block concurrent refreshes.
		ch := <-inIterate
		require.NoError(t, p.Release(ctx, nil /* txn */, rec.ID))
		require.NoError(t, c.Refresh(ctx, s.Clock().Now()))
		// Signal the Iterate loop to exit and wait for it to close the channel.
		close(ch)
		<-inIterate
	})
}

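// TestStart verifies that a Cache can only be started once and that starting
// it with an already-stopped Stopper fails with stop.ErrUnavailable.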
func TestStart(t *testing.T) {
	defer leaktest.AfterTest(t)()
	ctx := context.Background()
	setup := func() (*testcluster.TestCluster, *ptcache.Cache) {
		tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
		s := tc.Server(0)
		p := ptstorage.New(s.ClusterSettings(),
			s.InternalExecutor().(sqlutil.InternalExecutor))
		// Set the poll interval to be very long.
		protectedts.PollInterval.Override(&s.ClusterSettings().SV, 500*time.Hour)
		c := ptcache.New(ptcache.Config{
			Settings: s.ClusterSettings(),
			DB:       s.DB(),
			Storage:  p,
		})
		return tc, c
	}

	t.Run("double start", func(t *testing.T) {
		tc, c := setup()
		defer tc.Stopper().Stop(ctx)
		require.NoError(t, c.Start(ctx, tc.Stopper()))
		require.EqualError(t, c.Start(ctx, tc.Stopper()),
			"cannot start a Cache more than once")
	})
	t.Run("already stopped", func(t *testing.T) {
		tc, c := setup()
		tc.Stopper().Stop(ctx)
		require.EqualError(t, c.Start(ctx, tc.Stopper()),
			stop.ErrUnavailable.Error())
	})
}

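// TestQueryRecord verifies that QueryRecord reports whether a record exists as
// of the cache's last refresh, together with the timestamp at which that state
// was read.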
func TestQueryRecord(t *testing.T) {
	defer leaktest.AfterTest(t)()
	ctx := context.Background()
	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
	defer tc.Stopper().Stop(ctx)
	s := tc.Server(0)
	p := ptstorage.WithDatabase(ptstorage.New(s.ClusterSettings(),
		s.InternalExecutor().(sqlutil.InternalExecutor)), s.DB())
	// Set the poll interval to be very long.
	protectedts.PollInterval.Override(&s.ClusterSettings().SV, 500*time.Hour)
	c := ptcache.New(ptcache.Config{
		Settings: s.ClusterSettings(),
		DB:       s.DB(),
		Storage:  p,
	})
	require.NoError(t, c.Start(ctx, tc.Stopper()))

	// Wait for the initial fetch.
	waitForAsOfAfter(t, c, hlc.Timestamp{})
	// Create two records.
	sp42 := tableSpan(42)
	r1, createdAt1 := protect(t, s, p, sp42)
	r2, createdAt2 := protect(t, s, p, sp42)
	// Ensure that neither record is visible in the cache yet and that the read
	// timestamps precede the create timestamps.
	exists1, asOf := c.QueryRecord(ctx, r1.ID)
	require.False(t, exists1)
	require.True(t, asOf.Less(createdAt1))
	exists2, asOf := c.QueryRecord(ctx, r2.ID)
	require.False(t, exists2)
	require.True(t, asOf.Less(createdAt2))
	// Go refresh the state and make sure they both exist.
	require.NoError(t, c.Refresh(ctx, createdAt2))
	exists1, asOf = c.QueryRecord(ctx, r1.ID)
	require.True(t, exists1)
	require.True(t, !asOf.Less(createdAt1))
	exists2, asOf = c.QueryRecord(ctx, r2.ID)
	require.True(t, exists2)
	require.True(t, !asOf.Less(createdAt2))
	// Release 2 and then create 3.
	require.NoError(t, p.Release(ctx, nil /* txn */, r2.ID))
	r3, createdAt3 := protect(t, s, p, sp42)
	exists2, asOf = c.QueryRecord(ctx, r2.ID)
	require.True(t, exists2)
	require.True(t, asOf.Less(createdAt3))
	exists3, asOf := c.QueryRecord(ctx, r3.ID)
	require.False(t, exists3)
	require.True(t, asOf.Less(createdAt3))
	// Go refresh the state and make sure 1 and 3 exist.
	require.NoError(t, c.Refresh(ctx, createdAt3))
	exists1, _ = c.QueryRecord(ctx, r1.ID)
	require.True(t, exists1)
	exists2, _ = c.QueryRecord(ctx, r2.ID)
	require.False(t, exists2)
	exists3, _ = c.QueryRecord(ctx, r3.ID)
	require.True(t, exists3)
}

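// TestIterate verifies that Iterate visits exactly the records whose spans
// overlap the queried key range and that it stops early when the callback
// returns false.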
func TestIterate(t *testing.T) {
	ctx := context.Background()
	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
	defer tc.Stopper().Stop(ctx)
	s := tc.Server(0)
	p := ptstorage.WithDatabase(ptstorage.New(s.ClusterSettings(),
		s.InternalExecutor().(sqlutil.InternalExecutor)), s.DB())

	// Set the poll interval to be very long.
	protectedts.PollInterval.Override(&s.ClusterSettings().SV, 500*time.Hour)

	c := ptcache.New(ptcache.Config{
		Settings: s.ClusterSettings(),
		DB:       s.DB(),
		Storage:  p,
	})
	require.NoError(t, c.Start(ctx, tc.Stopper()))

	sp42 := tableSpan(42)
	sp43 := tableSpan(43)
	sp44 := tableSpan(44)
	r1, _ := protect(t, s, p, sp42)
	r2, _ := protect(t, s, p, sp43)
	r3, _ := protect(t, s, p, sp44)
	r4, _ := protect(t, s, p, sp42, sp43)
	require.NoError(t, c.Refresh(ctx, s.Clock().Now()))
	t.Run("all", func(t *testing.T) {
		var recs records
		c.Iterate(ctx, sp42.Key, sp44.EndKey, recs.accumulate)
		require.EqualValues(t, recs.sorted(), (&records{r1, r2, r3, r4}).sorted())
	})
	t.Run("some", func(t *testing.T) {
		var recs records
		c.Iterate(ctx, sp42.Key, sp42.EndKey, recs.accumulate)
		require.EqualValues(t, recs.sorted(), (&records{r1, r4}).sorted())
	})
	t.Run("none", func(t *testing.T) {
		var recs records
		c.Iterate(ctx, tableSpan(45).Key, tableSpan(45).EndKey, recs.accumulate)
		require.Len(t, recs, 0)
	})
	t.Run("early return from iteration", func(t *testing.T) {
		var called int
		c.Iterate(ctx, sp42.Key, sp44.EndKey, func(*ptpb.Record) (wantMore bool) {
			called++
			return false
		})
		require.Equal(t, 1, called)
	})
}

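// records is a helper slice type used to accumulate and order the Records
// observed during Iterate calls.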
type records []*ptpb.Record

func (recs *records) accumulate(r *ptpb.Record) (wantMore bool) {
	(*recs) = append(*recs, r)
	return true
}

func (recs *records) sorted() []*ptpb.Record {
	sort.Slice(*recs, func(i, j int) bool {
		return bytes.Compare((*recs)[i].ID[:], (*recs)[j].ID[:]) < 0
	})
	return *recs
}

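// TestSettingChangedLeadsToFetch verifies that shortening the poll interval
// setting causes the Cache to fetch the protected timestamp state again soon
// afterwards.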
func TestSettingChangedLeadsToFetch(t *testing.T) {
	ctx := context.Background()
	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
	defer tc.Stopper().Stop(ctx)
	s := tc.Server(0)
	p := ptstorage.WithDatabase(ptstorage.New(s.ClusterSettings(),
		s.InternalExecutor().(sqlutil.InternalExecutor)), s.DB())

	// Set the poll interval to be very long.
	protectedts.PollInterval.Override(&s.ClusterSettings().SV, 500*time.Hour)

	c := ptcache.New(ptcache.Config{
		Settings: s.ClusterSettings(),
		DB:       s.DB(),
		Storage:  p,
	})
	require.NoError(t, c.Start(ctx, tc.Stopper()))

	// Make sure that the initial state has been fetched.
	ts := waitForAsOfAfter(t, c, hlc.Timestamp{})
	// Make sure there isn't another rapid fetch.
	// If there were a bug, this check might fail under stress.
	time.Sleep(time.Millisecond)
	_, asOf := c.QueryRecord(ctx, uuid.UUID{})
	require.Equal(t, asOf, ts)
	// Set the polling interval back to something very short.
	protectedts.PollInterval.Override(&s.ClusterSettings().SV, 100*time.Microsecond)
	// Ensure that the state is updated again soon.
	waitForAsOfAfter(t, c, ts)
}

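// waitForAsOfAfter waits until the Cache's view of the protected timestamp
// state has advanced past ts and returns the timestamp as of which that state
// was read.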
func waitForAsOfAfter(t *testing.T, c protectedts.Cache, ts hlc.Timestamp) (asOf hlc.Timestamp) {
	testutils.SucceedsSoon(t, func() error {
		var exists bool
		exists, asOf = c.QueryRecord(context.Background(), uuid.UUID{})
		require.False(t, exists)
		if !ts.Less(asOf) {
			return errors.Errorf("expected an update to occur")
		}
		return nil
	})
	return asOf
}

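// tableSpan returns the span covering the system-tenant table with the given
// ID.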
func tableSpan(tableID uint32) roachpb.Span {
	return roachpb.Span{
		Key:    keys.SystemSQLCodec.TablePrefix(tableID),
		EndKey: keys.SystemSQLCodec.TablePrefix(tableID).PrefixEnd(),
	}
}

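// protect writes a new PROTECT_AFTER record covering the given spans in its
// own transaction and returns the record along with the timestamp at which it
// was committed.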
func protect(
	t *testing.T, s serverutils.TestServerInterface, p protectedts.Storage, spans ...roachpb.Span,
) (r *ptpb.Record, createdAt hlc.Timestamp) {
	protectTS := s.Clock().Now()
	r = &ptpb.Record{
		ID:        uuid.MakeV4(),
		Timestamp: protectTS,
		Mode:      ptpb.PROTECT_AFTER,
		Spans:     spans,
	}
	ctx := context.Background()
	txn := s.DB().NewTxn(ctx, "test")
	require.NoError(t, p.Protect(ctx, txn, r))
	require.NoError(t, txn.Commit(ctx))
	_, err := p.GetRecord(ctx, nil, r.ID)
	require.NoError(t, err)
	createdAt = txn.CommitTimestamp()
	return r, createdAt
}

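// metaTableSpan and recordsTableSpan cover the system tables which back the
// protected timestamp subsystem.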
var (
	metaTableSpan    = tableSpan(keys.ProtectedTimestampsMetaTableID)
	recordsTableSpan = tableSpan(keys.ProtectedTimestampsRecordsTableID)
)

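// scanTracker counts scans of the protected timestamp meta and records tables
// and optionally injects behavior via a configurable request filter.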
type scanTracker struct {
	mu                syncutil.Mutex
	metaTableScans    int
	recordsTableScans int
	filterFunc        func(ba roachpb.BatchRequest) *roachpb.Error
}

func (st *scanTracker) resetCounters() {
	st.mu.Lock()
	defer st.mu.Unlock()
	st.metaTableScans, st.recordsTableScans = 0, 0
}

func (st *scanTracker) verifyCounters(t *testing.T, expMeta, expRecords int) {
	t.Helper()
	st.mu.Lock()
	defer st.mu.Unlock()
	require.Equal(t, expMeta, st.metaTableScans)
	require.Equal(t, expRecords, st.recordsTableScans)
}

func (st *scanTracker) setFilter(f func(roachpb.BatchRequest) *roachpb.Error) {
	st.mu.Lock()
	defer st.mu.Unlock()
	st.filterFunc = f
}

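// requestFilter implements kvserverbase.ReplicaRequestFilter. It counts scans
// which touch the meta and records tables and then delegates to the installed
// filter function, if any.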
func (st *scanTracker) requestFilter(_ context.Context, ba roachpb.BatchRequest) *roachpb.Error {
	st.mu.Lock()
	defer st.mu.Unlock()
	if scanReq, ok := ba.GetArg(roachpb.Scan); ok {
		scan := scanReq.(*roachpb.ScanRequest)
		if scan.Span().Overlaps(metaTableSpan) {
			st.metaTableScans++
		} else if scan.Span().Overlaps(recordsTableSpan) {
			st.recordsTableScans++
		}
	}
	if st.filterFunc != nil {
		return st.filterFunc(ba)
	}
	return nil
}