github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/kv/kvserver/protectedts/ptstorage/storage_test.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package ptstorage_test
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	"fmt"
    17  	"math"
    18  	"math/rand"
    19  	"regexp"
    20  	"sort"
    21  	"strconv"
    22  	"testing"
    23  
    24  	"github.com/cockroachdb/cockroach/pkg/base"
    25  	"github.com/cockroachdb/cockroach/pkg/keys"
    26  	"github.com/cockroachdb/cockroach/pkg/kv"
    27  	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts"
    28  	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptpb"
    29  	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptstorage"
    30  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    31  	"github.com/cockroachdb/cockroach/pkg/security"
    32  	"github.com/cockroachdb/cockroach/pkg/sql"
    33  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    34  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    35  	"github.com/cockroachdb/cockroach/pkg/sql/sqlutil"
    36  	"github.com/cockroachdb/cockroach/pkg/testutils"
    37  	"github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
    38  	"github.com/cockroachdb/cockroach/pkg/util/hlc"
    39  	"github.com/cockroachdb/cockroach/pkg/util/log"
    40  	"github.com/cockroachdb/cockroach/pkg/util/protoutil"
    41  	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
    42  	"github.com/cockroachdb/cockroach/pkg/util/uuid"
    43  	"github.com/cockroachdb/errors"
    44  	"github.com/stretchr/testify/require"
    45  )
    46  
    47  func TestStorage(t *testing.T) {
    48  	for _, test := range testCases {
    49  		t.Run(test.name, test.run)
    50  	}
    51  }
    52  
    53  var testCases = []testCase{
    54  	{
    55  		name: "Protect - simple positive",
    56  		ops: []op{
    57  			protectOp{spans: tableSpans(42)},
    58  		},
    59  	},
    60  	{
    61  		name: "Protect - no spans",
    62  		ops: []op{
    63  			protectOp{
    64  				expErr: "invalid empty set of spans",
    65  			},
    66  		},
    67  	},
    68  	{
    69  		name: "Protect - zero timestamp",
    70  		ops: []op{
    71  			funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) {
    72  				rec := newRecord(hlc.Timestamp{}, "", nil, tableSpan(42))
    73  				err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    74  					return tCtx.pts.Protect(ctx, txn, &rec)
    75  				})
    76  				require.Regexp(t, "invalid zero value timestamp", err.Error())
    77  			}),
    78  		},
    79  	},
    80  	{
    81  		name: "Protect - already verified",
    82  		ops: []op{
    83  			funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) {
    84  				rec := newRecord(tCtx.tc.Server(0).Clock().Now(), "", nil, tableSpan(42))
    85  				rec.Verified = true
    86  				err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    87  					return tCtx.pts.Protect(ctx, txn, &rec)
    88  				})
    89  				require.Regexp(t, "cannot create a verified record", err.Error())
    90  			}),
    91  		},
    92  	},
    93  	{
    94  		name: "Protect - already exists",
    95  		ops: []op{
    96  			protectOp{spans: tableSpans(42)},
    97  			funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) {
    98  				rec := newRecord(tCtx.tc.Server(0).Clock().Now(), "", nil, tableSpan(42))
    99  				rec.ID = pickOneRecord(tCtx)
   100  				err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   101  					return tCtx.pts.Protect(ctx, txn, &rec)
   102  				})
   103  				require.EqualError(t, err, protectedts.ErrExists.Error())
   104  			}),
   105  		},
   106  	},
   107  	{
   108  		name: "Protect - too many spans",
   109  		ops: []op{
   110  			protectOp{spans: tableSpans(42)},
   111  			funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) {
   112  				_, err := tCtx.tc.ServerConn(0).Exec("SET CLUSTER SETTING kv.protectedts.max_spans = $1", 3)
   113  				require.NoError(t, err)
   114  			}),
   115  			protectOp{
   116  				metaType: "asdf",
   117  				meta:     []byte("asdf"),
   118  				spans:    tableSpans(1, 2, 3),
   119  				expErr:   "protectedts: limit exceeded: 1\\+3 > 3 spans",
   120  			},
   121  			protectOp{
   122  				metaType: "asdf",
   123  				meta:     []byte("asdf"),
   124  				spans:    tableSpans(1, 2),
   125  			},
   126  			releaseOp{idFunc: pickOneRecord},
   127  			releaseOp{idFunc: pickOneRecord},
   128  			protectOp{spans: tableSpans(1)},
   129  			protectOp{spans: tableSpans(2)},
   130  			protectOp{spans: tableSpans(3)},
   131  			protectOp{
   132  				spans:  tableSpans(4),
   133  				expErr: "protectedts: limit exceeded: 3\\+1 > 3 spans",
   134  			},
   135  		},
   136  	},
   137  	{
   138  		name: "Protect - too many bytes",
   139  		ops: []op{
   140  			protectOp{spans: tableSpans(42)},
   141  			funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) {
   142  				_, err := tCtx.tc.ServerConn(0).Exec("SET CLUSTER SETTING kv.protectedts.max_bytes = $1", 1024)
   143  				require.NoError(t, err)
   144  			}),
   145  			protectOp{
   146  				spans: append(tableSpans(1, 2),
   147  					func() roachpb.Span {
   148  						s := tableSpan(3)
   149  						s.EndKey = append(s.EndKey, bytes.Repeat([]byte{'a'}, 1024)...)
   150  						return s
   151  					}()),
   152  				expErr: "protectedts: limit exceeded: 8\\+1050 > 1024 bytes",
   153  			},
   154  			protectOp{
   155  				spans: tableSpans(1, 2),
   156  			},
   157  		},
   158  	},
   159  	{
   160  		name: "GetRecord - does not exist",
   161  		ops: []op{
   162  			funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) {
   163  				var rec *ptpb.Record
   164  				err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   165  					rec, err = tCtx.pts.GetRecord(ctx, txn, randomID(tCtx))
   166  					return err
   167  				})
   168  				require.EqualError(t, err, protectedts.ErrNotExists.Error())
   169  				require.Nil(t, rec)
   170  			}),
   171  		},
   172  	},
   173  	{
   174  		name: "MarkVerified",
   175  		ops: []op{
   176  			protectOp{spans: tableSpans(42)},
   177  			markVerifiedOp{idFunc: pickOneRecord},
   178  			markVerifiedOp{idFunc: pickOneRecord}, // it's idempotent
   179  			markVerifiedOp{
   180  				idFunc: randomID,
   181  				expErr: protectedts.ErrNotExists.Error(),
   182  			},
   183  		},
   184  	},
   185  	{
   186  		name: "Release",
   187  		ops: []op{
   188  			protectOp{spans: tableSpans(42)},
   189  			releaseOp{idFunc: pickOneRecord},
   190  			releaseOp{
   191  				idFunc: randomID,
   192  				expErr: protectedts.ErrNotExists.Error(),
   193  			},
   194  		},
   195  	},
   196  	{
   197  		name: "nil transaction errors",
   198  		ops: []op{
   199  			funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) {
   200  				rec := newRecord(tCtx.tc.Server(0).Clock().Now(), "", nil, tableSpan(42))
   201  				const msg = "must provide a non-nil transaction"
   202  				require.Regexp(t, msg, tCtx.pts.Protect(ctx, nil /* txn */, &rec).Error())
   203  				require.Regexp(t, msg, tCtx.pts.Release(ctx, nil /* txn */, uuid.MakeV4()).Error())
   204  				require.Regexp(t, msg, tCtx.pts.MarkVerified(ctx, nil /* txn */, uuid.MakeV4()).Error())
   205  				_, err := tCtx.pts.GetRecord(ctx, nil /* txn */, uuid.MakeV4())
   206  				require.Regexp(t, msg, err.Error())
   207  				_, err = tCtx.pts.GetMetadata(ctx, nil /* txn */)
   208  				require.Regexp(t, msg, err.Error())
   209  				_, err = tCtx.pts.GetState(ctx, nil /* txn */)
   210  				require.Regexp(t, msg, err.Error())
   211  			}),
   212  		},
   213  	},
   214  }
   215  
// testContext carries the handles an op needs to drive the storage under test
// together with the state the test expects that storage to hold.
type testContext struct {
	// pts is the storage implementation being exercised.
	pts protectedts.Storage
	// tc is the test cluster whose first server backs pts and db.
	tc  *testcluster.TestCluster
	// db is used to open the transactions passed to pts.
	db  *kv.DB

	// state is the expected contents of storage. Ops update it after each
	// successful mutation and testCase.run compares it against the values
	// returned by GetState, GetMetadata and GetRecord.
	state ptpb.State
}
   223  
// op is a single step of a testCase. Implementations act on the storage held
// by testCtx and are responsible for keeping testCtx.state in line with any
// change they make.
type op interface {
	run(ctx context.Context, t *testing.T, testCtx *testContext)
}

// funcOp adapts a plain function to the op interface for steps that need
// arbitrary logic.
type funcOp func(ctx context.Context, t *testing.T, tCtx *testContext)

// run invokes the wrapped function with the arguments unchanged.
func (f funcOp) run(ctx context.Context, t *testing.T, tCtx *testContext) {
	f(ctx, t, tCtx)
}
   233  
   234  type releaseOp struct {
   235  	idFunc func(tCtx *testContext) uuid.UUID
   236  	expErr string
   237  }
   238  
   239  func (r releaseOp) run(ctx context.Context, t *testing.T, tCtx *testContext) {
   240  	id := r.idFunc(tCtx)
   241  	err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   242  		return tCtx.pts.Release(ctx, txn, id)
   243  	})
   244  	if !testutils.IsError(err, r.expErr) {
   245  		t.Fatalf("expected error to match %q, got %q", r.expErr, err)
   246  	}
   247  	if err == nil {
   248  		i := sort.Search(len(tCtx.state.Records), func(i int) bool {
   249  			return bytes.Compare(id[:], tCtx.state.Records[i].ID[:]) <= 0
   250  		})
   251  		rec := tCtx.state.Records[i]
   252  		tCtx.state.Records = append(tCtx.state.Records[:i], tCtx.state.Records[i+1:]...)
   253  		if len(tCtx.state.Records) == 0 {
   254  			tCtx.state.Records = nil
   255  		}
   256  		tCtx.state.Version++
   257  		tCtx.state.NumRecords--
   258  		tCtx.state.NumSpans -= uint64(len(rec.Spans))
   259  		encoded, err := protoutil.Marshal(&ptstorage.Spans{Spans: rec.Spans})
   260  		require.NoError(t, err)
   261  		tCtx.state.TotalBytes -= uint64(len(encoded) + len(rec.Meta) + len(rec.MetaType))
   262  	}
   263  }
   264  
   265  type markVerifiedOp struct {
   266  	idFunc func(tCtx *testContext) uuid.UUID
   267  	expErr string
   268  }
   269  
   270  func (mv markVerifiedOp) run(ctx context.Context, t *testing.T, tCtx *testContext) {
   271  	id := mv.idFunc(tCtx)
   272  	err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   273  		return tCtx.pts.MarkVerified(ctx, txn, id)
   274  	})
   275  	if !testutils.IsError(err, mv.expErr) {
   276  		t.Fatalf("expected error to match %q, got %q", mv.expErr, err)
   277  	}
   278  	if err == nil {
   279  		i := sort.Search(len(tCtx.state.Records), func(i int) bool {
   280  			return bytes.Compare(id[:], tCtx.state.Records[i].ID[:]) <= 0
   281  		})
   282  		tCtx.state.Records[i].Verified = true
   283  	}
   284  }
   285  
   286  type protectOp struct {
   287  	idFunc   func(*testContext) uuid.UUID
   288  	metaType string
   289  	meta     []byte
   290  	spans    []roachpb.Span
   291  	expErr   string
   292  }
   293  
   294  func (p protectOp) run(ctx context.Context, t *testing.T, tCtx *testContext) {
   295  	rec := newRecord(tCtx.tc.Server(0).Clock().Now(), p.metaType, p.meta, p.spans...)
   296  	if p.idFunc != nil {
   297  		rec.ID = p.idFunc(tCtx)
   298  	}
   299  	err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   300  		return tCtx.pts.Protect(ctx, txn, &rec)
   301  	})
   302  	if !testutils.IsError(err, p.expErr) {
   303  		t.Fatalf("expected error to match %q, got %q", p.expErr, err)
   304  	}
   305  	if err == nil {
   306  		i := sort.Search(len(tCtx.state.Records), func(i int) bool {
   307  			return bytes.Compare(rec.ID[:], tCtx.state.Records[i].ID[:]) <= 0
   308  		})
   309  		tail := tCtx.state.Records[i:]
   310  		tCtx.state.Records = append(tCtx.state.Records[:i:i], rec)
   311  		tCtx.state.Records = append(tCtx.state.Records, tail...)
   312  		tCtx.state.Version++
   313  		tCtx.state.NumRecords++
   314  		tCtx.state.NumSpans += uint64(len(rec.Spans))
   315  		encoded, err := protoutil.Marshal(&ptstorage.Spans{Spans: p.spans})
   316  		require.NoError(t, err)
   317  		tCtx.state.TotalBytes += uint64(len(encoded) + len(p.meta) + len(p.metaType))
   318  	}
   319  }
   320  
   321  type testCase struct {
   322  	name string
   323  	ops  []op
   324  }
   325  
   326  func (test testCase) run(t *testing.T) {
   327  	ctx := context.Background()
   328  	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
   329  	defer tc.Stopper().Stop(ctx)
   330  
   331  	s := tc.Server(0)
   332  	pts := ptstorage.New(s.ClusterSettings(),
   333  		s.InternalExecutor().(*sql.InternalExecutor))
   334  	db := s.DB()
   335  	tCtx := testContext{
   336  		pts: pts,
   337  		db:  db,
   338  		tc:  tc,
   339  	}
   340  	verify := func(t *testing.T) {
   341  		var state ptpb.State
   342  		require.NoError(t, db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   343  			state, err = pts.GetState(ctx, txn)
   344  			return err
   345  		}))
   346  		var md ptpb.Metadata
   347  		require.NoError(t, db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   348  			md, err = pts.GetMetadata(ctx, txn)
   349  			return err
   350  		}))
   351  		require.EqualValues(t, tCtx.state, state)
   352  		require.EqualValues(t, tCtx.state.Metadata, md)
   353  		for _, r := range tCtx.state.Records {
   354  			var rec *ptpb.Record
   355  			require.NoError(t, db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   356  				rec, err = pts.GetRecord(ctx, txn, r.ID)
   357  				return err
   358  			}))
   359  			require.EqualValues(t, &r, rec)
   360  		}
   361  	}
   362  
   363  	for i, tOp := range test.ops {
   364  		if !t.Run(strconv.Itoa(i), func(t *testing.T) {
   365  			tOp.run(ctx, t, &tCtx)
   366  			verify(t)
   367  		}) {
   368  			break
   369  		}
   370  	}
   371  }
   372  
// randomID returns a freshly generated UUID and ignores its argument. Its
// signature matches the idFunc fields of the ops, where the cases in this file
// use it to name a record that was never written.
func randomID(*testContext) uuid.UUID {
	return uuid.MakeV4()
}
   376  
   377  func pickOneRecord(tCtx *testContext) uuid.UUID {
   378  	numRecords := len(tCtx.state.Records)
   379  	if numRecords == 0 {
   380  		panic(fmt.Errorf("cannot pick one from zero records: %+v", tCtx))
   381  	}
   382  	return tCtx.state.Records[rand.Intn(numRecords)].ID
   383  }
   384  
   385  func tableSpan(tableID uint32) roachpb.Span {
   386  	return roachpb.Span{
   387  		Key:    keys.SystemSQLCodec.TablePrefix(tableID),
   388  		EndKey: keys.SystemSQLCodec.TablePrefix(tableID).PrefixEnd(),
   389  	}
   390  }
   391  
   392  func tableSpans(tableIDs ...uint32) []roachpb.Span {
   393  	spans := make([]roachpb.Span, len(tableIDs))
   394  	for i, tableID := range tableIDs {
   395  		spans[i] = tableSpan(tableID)
   396  	}
   397  	return spans
   398  }
   399  
   400  func newRecord(ts hlc.Timestamp, metaType string, meta []byte, spans ...roachpb.Span) ptpb.Record {
   401  	return ptpb.Record{
   402  		ID:        uuid.MakeV4(),
   403  		Timestamp: ts,
   404  		Mode:      ptpb.PROTECT_AFTER,
   405  		MetaType:  metaType,
   406  		Meta:      meta,
   407  		Spans:     spans,
   408  	}
   409  }
   410  
   411  // TestCorruptData exercises the handling of malformed data inside the protected
   412  // timestamp tables. We don't anticipate this ever happening and it would
   413  // generally be a bad thing. Nevertheless, we plan for the worst and need to
   414  // understand the system behavior in that scenario.
   415  //
   416  // The main source of corruption in the subsystem would be malformed encoded
   417  // spans. Another possible form of corruption would be that the metadata does
   418  // not align with the data. The metadata misalignment will not lead to a
   419  // foreground error anywhere. Corrupt spans could.
   420  //
   421  // A corrupt spans entry only impacts GetRecord and GetState. In both cases
   422  // we omit the spans from the entry and return it, logging the error. We prefer
   423  // logging the error over returning it as there's a chance that the code is
   424  // merely trying to remove the malformed data. The returned Record which
   425  // contains no spans will be invalid and cannot be Verified. Such a Record
   426  // can be removed.
   427  func TestCorruptData(t *testing.T) {
   428  	ctx := context.Background()
   429  
   430  	t.Run("corrupt spans", func(t *testing.T) {
   431  		tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
   432  		defer tc.Stopper().Stop(ctx)
   433  
   434  		s := tc.Server(0)
   435  		pts := ptstorage.New(s.ClusterSettings(),
   436  			s.InternalExecutor().(*sql.InternalExecutor))
   437  
   438  		rec := newRecord(s.Clock().Now(), "foo", []byte("bar"), tableSpan(42))
   439  		require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   440  			return pts.Protect(ctx, txn, &rec)
   441  		}))
   442  		ie := tc.Server(0).InternalExecutor().(sqlutil.InternalExecutor)
   443  		affected, err := ie.ExecEx(
   444  			ctx, "corrupt-data", nil, /* txn */
   445  			sqlbase.InternalExecutorSessionDataOverride{User: security.NodeUser},
   446  			"UPDATE system.protected_ts_records SET spans = $1 WHERE id = $2",
   447  			[]byte("junk"), rec.ID.String())
   448  		require.NoError(t, err)
   449  		require.Equal(t, 1, affected)
   450  
   451  		// Set the log scope so we can introspect the logged errors.
   452  		scope := log.Scope(t)
   453  		defer scope.Close(t)
   454  
   455  		var got *ptpb.Record
   456  		msg := regexp.MustCompile("failed to unmarshal spans for " + rec.ID.String() + ": ")
   457  		require.Regexp(t, msg,
   458  			s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   459  				got, err = pts.GetRecord(ctx, txn, rec.ID)
   460  				return err
   461  			}).Error())
   462  		require.Nil(t, got)
   463  		require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   464  			_, err = pts.GetState(ctx, txn)
   465  			return err
   466  		}))
   467  		log.Flush()
   468  		entries, err := log.FetchEntriesFromFiles(0, math.MaxInt64, 100, msg)
   469  		require.NoError(t, err)
   470  		require.Len(t, entries, 1)
   471  		for _, e := range entries {
   472  			require.Equal(t, log.Severity_ERROR, e.Severity)
   473  		}
   474  	})
   475  	t.Run("corrupt hlc timestamp", func(t *testing.T) {
   476  		tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
   477  		defer tc.Stopper().Stop(ctx)
   478  
   479  		s := tc.Server(0)
   480  		pts := ptstorage.New(s.ClusterSettings(),
   481  			s.InternalExecutor().(*sql.InternalExecutor))
   482  
   483  		rec := newRecord(s.Clock().Now(), "foo", []byte("bar"), tableSpan(42))
   484  		require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   485  			return pts.Protect(ctx, txn, &rec)
   486  		}))
   487  
   488  		// This timestamp has too many logical digits and thus will fail parsing.
   489  		var d tree.DDecimal
   490  		d.SetFinite(math.MaxInt32, -12)
   491  		ie := tc.Server(0).InternalExecutor().(sqlutil.InternalExecutor)
   492  		affected, err := ie.ExecEx(
   493  			ctx, "corrupt-data", nil, /* txn */
   494  			sqlbase.InternalExecutorSessionDataOverride{User: security.NodeUser},
   495  			"UPDATE system.protected_ts_records SET ts = $1 WHERE id = $2",
   496  			d.String(), rec.ID.String())
   497  		require.NoError(t, err)
   498  		require.Equal(t, 1, affected)
   499  
   500  		// Set the log scope so we can introspect the logged errors.
   501  		scope := log.Scope(t)
   502  		defer scope.Close(t)
   503  
   504  		var got *ptpb.Record
   505  		msg := regexp.MustCompile("failed to parse timestamp for " + rec.ID.String() +
   506  			": logical part has too many digits")
   507  		require.Regexp(t, msg,
   508  			s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   509  				got, err = pts.GetRecord(ctx, txn, rec.ID)
   510  				return err
   511  			}))
   512  		require.Nil(t, got)
   513  		require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) {
   514  			_, err = pts.GetState(ctx, txn)
   515  			return err
   516  		}))
   517  		log.Flush()
   518  
   519  		entries, err := log.FetchEntriesFromFiles(0, math.MaxInt64, 100, msg)
   520  		require.NoError(t, err)
   521  		require.Len(t, entries, 1)
   522  		for _, e := range entries {
   523  			require.Equal(t, log.Severity_ERROR, e.Severity)
   524  		}
   525  	})
   526  }
   527  
   528  // TestErrorsFromSQL ensures that errors from the underlying InternalExecutor
   529  // are properly transmitted back to the client.
   530  func TestErrorsFromSQL(t *testing.T) {
   531  	ctx := context.Background()
   532  	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
   533  	defer tc.Stopper().Stop(ctx)
   534  
   535  	s := tc.Server(0)
   536  	ie := s.InternalExecutor().(sqlutil.InternalExecutor)
   537  	wrappedIE := &wrappedInternalExecutor{wrapped: ie}
   538  	pts := ptstorage.New(s.ClusterSettings(), wrappedIE)
   539  
   540  	wrappedIE.setErrFunc(func(string) error {
   541  		return errors.New("boom")
   542  	})
   543  	rec := newRecord(s.Clock().Now(), "foo", []byte("bar"), tableSpan(42))
   544  	require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   545  		return pts.Protect(ctx, txn, &rec)
   546  	}), fmt.Sprintf("failed to write record %v: boom", rec.ID))
   547  	require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   548  		_, err := pts.GetRecord(ctx, txn, rec.ID)
   549  		return err
   550  	}), fmt.Sprintf("failed to read record %v: boom", rec.ID))
   551  	require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   552  		return pts.MarkVerified(ctx, txn, rec.ID)
   553  	}), fmt.Sprintf("failed to mark record %v as verified: boom", rec.ID))
   554  	require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   555  		return pts.Release(ctx, txn, rec.ID)
   556  	}), fmt.Sprintf("failed to release record %v: boom", rec.ID))
   557  	require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   558  		_, err := pts.GetMetadata(ctx, txn)
   559  		return err
   560  	}), "failed to read metadata: boom")
   561  	require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   562  		_, err := pts.GetState(ctx, txn)
   563  		return err
   564  	}), "failed to read metadata: boom")
   565  	// Test that we get an error retrieving the records in GetState.
   566  	// The preceding call tested the error while retriving the metadata in a
   567  	// call to GetState.
   568  	var seen bool
   569  	wrappedIE.setErrFunc(func(string) error {
   570  		if !seen {
   571  			seen = true
   572  			return nil
   573  		}
   574  		return errors.New("boom")
   575  	})
   576  	require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
   577  		_, err := pts.GetState(ctx, txn)
   578  		return err
   579  	}), "failed to read records: boom")
   580  }
   581  
   582  // wrappedInternalExecutor allows errors to be injected in SQL execution.
   583  type wrappedInternalExecutor struct {
   584  	wrapped sqlutil.InternalExecutor
   585  
   586  	mu struct {
   587  		syncutil.RWMutex
   588  		errFunc func(statement string) error
   589  	}
   590  }
   591  
   592  var _ sqlutil.InternalExecutor = &wrappedInternalExecutor{}
   593  
   594  func (ie *wrappedInternalExecutor) Exec(
   595  	ctx context.Context, opName string, txn *kv.Txn, statement string, params ...interface{},
   596  ) (int, error) {
   597  	panic("unimplemented")
   598  }
   599  
   600  func (ie *wrappedInternalExecutor) ExecEx(
   601  	ctx context.Context,
   602  	opName string,
   603  	txn *kv.Txn,
   604  	o sqlbase.InternalExecutorSessionDataOverride,
   605  	stmt string,
   606  	qargs ...interface{},
   607  ) (int, error) {
   608  	panic("unimplemented")
   609  }
   610  
   611  func (ie *wrappedInternalExecutor) QueryEx(
   612  	ctx context.Context,
   613  	opName string,
   614  	txn *kv.Txn,
   615  	session sqlbase.InternalExecutorSessionDataOverride,
   616  	stmt string,
   617  	qargs ...interface{},
   618  ) ([]tree.Datums, error) {
   619  	if f := ie.getErrFunc(); f != nil {
   620  		if err := f(stmt); err != nil {
   621  			return nil, err
   622  		}
   623  	}
   624  	return ie.wrapped.QueryEx(ctx, opName, txn, session, stmt, qargs...)
   625  }
   626  
   627  func (ie *wrappedInternalExecutor) QueryWithCols(
   628  	ctx context.Context,
   629  	opName string,
   630  	txn *kv.Txn,
   631  	o sqlbase.InternalExecutorSessionDataOverride,
   632  	statement string,
   633  	qargs ...interface{},
   634  ) ([]tree.Datums, sqlbase.ResultColumns, error) {
   635  	panic("unimplemented")
   636  }
   637  
   638  func (ie *wrappedInternalExecutor) QueryRowEx(
   639  	ctx context.Context,
   640  	opName string,
   641  	txn *kv.Txn,
   642  	session sqlbase.InternalExecutorSessionDataOverride,
   643  	stmt string,
   644  	qargs ...interface{},
   645  ) (tree.Datums, error) {
   646  	if f := ie.getErrFunc(); f != nil {
   647  		if err := f(stmt); err != nil {
   648  			return nil, err
   649  		}
   650  	}
   651  	return ie.wrapped.QueryRowEx(ctx, opName, txn, session, stmt, qargs...)
   652  }
   653  
   654  func (ie *wrappedInternalExecutor) Query(
   655  	ctx context.Context, opName string, txn *kv.Txn, statement string, params ...interface{},
   656  ) ([]tree.Datums, error) {
   657  	panic("not implemented")
   658  }
   659  
   660  func (ie *wrappedInternalExecutor) QueryRow(
   661  	ctx context.Context, opName string, txn *kv.Txn, statement string, qargs ...interface{},
   662  ) (tree.Datums, error) {
   663  	panic("not implemented")
   664  }
   665  
   666  func (ie *wrappedInternalExecutor) getErrFunc() func(statement string) error {
   667  	ie.mu.RLock()
   668  	defer ie.mu.RUnlock()
   669  	return ie.mu.errFunc
   670  }
   671  
   672  func (ie *wrappedInternalExecutor) setErrFunc(f func(statement string) error) {
   673  	ie.mu.Lock()
   674  	defer ie.mu.Unlock()
   675  	ie.mu.errFunc = f
   676  }