github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/kv/kvserver/protectedts/ptstorage/storage_test.go (about) 1 // Copyright 2019 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package ptstorage_test 12 13 import ( 14 "bytes" 15 "context" 16 "fmt" 17 "math" 18 "math/rand" 19 "regexp" 20 "sort" 21 "strconv" 22 "testing" 23 24 "github.com/cockroachdb/cockroach/pkg/base" 25 "github.com/cockroachdb/cockroach/pkg/keys" 26 "github.com/cockroachdb/cockroach/pkg/kv" 27 "github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts" 28 "github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptpb" 29 "github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptstorage" 30 "github.com/cockroachdb/cockroach/pkg/roachpb" 31 "github.com/cockroachdb/cockroach/pkg/security" 32 "github.com/cockroachdb/cockroach/pkg/sql" 33 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 34 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 35 "github.com/cockroachdb/cockroach/pkg/sql/sqlutil" 36 "github.com/cockroachdb/cockroach/pkg/testutils" 37 "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" 38 "github.com/cockroachdb/cockroach/pkg/util/hlc" 39 "github.com/cockroachdb/cockroach/pkg/util/log" 40 "github.com/cockroachdb/cockroach/pkg/util/protoutil" 41 "github.com/cockroachdb/cockroach/pkg/util/syncutil" 42 "github.com/cockroachdb/cockroach/pkg/util/uuid" 43 "github.com/cockroachdb/errors" 44 "github.com/stretchr/testify/require" 45 ) 46 47 func TestStorage(t *testing.T) { 48 for _, test := range testCases { 49 t.Run(test.name, test.run) 50 } 51 } 52 53 var testCases = []testCase{ 54 { 55 name: "Protect - simple positive", 56 ops: []op{ 57 protectOp{spans: tableSpans(42)}, 58 }, 59 }, 60 { 61 name: "Protect - no spans", 62 ops: []op{ 63 protectOp{ 64 expErr: "invalid empty set of spans", 65 }, 66 }, 67 }, 68 { 69 name: "Protect - zero timestamp", 70 ops: []op{ 71 funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) { 72 rec := newRecord(hlc.Timestamp{}, "", nil, tableSpan(42)) 73 err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 74 return tCtx.pts.Protect(ctx, txn, &rec) 75 }) 76 require.Regexp(t, "invalid zero value timestamp", err.Error()) 77 }), 78 }, 79 }, 80 { 81 name: "Protect - already verified", 82 ops: []op{ 83 funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) { 84 rec := newRecord(tCtx.tc.Server(0).Clock().Now(), "", nil, tableSpan(42)) 85 rec.Verified = true 86 err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 87 return tCtx.pts.Protect(ctx, txn, &rec) 88 }) 89 require.Regexp(t, "cannot create a verified record", err.Error()) 90 }), 91 }, 92 }, 93 { 94 name: "Protect - already exists", 95 ops: []op{ 96 protectOp{spans: tableSpans(42)}, 97 funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) { 98 rec := newRecord(tCtx.tc.Server(0).Clock().Now(), "", nil, tableSpan(42)) 99 rec.ID = pickOneRecord(tCtx) 100 err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 101 return tCtx.pts.Protect(ctx, txn, &rec) 102 }) 103 require.EqualError(t, err, protectedts.ErrExists.Error()) 104 }), 105 }, 106 }, 107 { 108 name: "Protect - too many spans", 109 ops: []op{ 110 protectOp{spans: tableSpans(42)}, 111 funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) { 112 _, err := tCtx.tc.ServerConn(0).Exec("SET CLUSTER SETTING kv.protectedts.max_spans = $1", 3) 113 require.NoError(t, err) 114 }), 115 protectOp{ 116 metaType: "asdf", 117 meta: []byte("asdf"), 118 spans: tableSpans(1, 2, 3), 119 expErr: "protectedts: limit exceeded: 1\\+3 > 3 spans", 120 }, 121 protectOp{ 122 metaType: "asdf", 123 meta: []byte("asdf"), 124 spans: tableSpans(1, 2), 125 }, 126 releaseOp{idFunc: pickOneRecord}, 127 releaseOp{idFunc: pickOneRecord}, 128 protectOp{spans: tableSpans(1)}, 129 protectOp{spans: tableSpans(2)}, 130 protectOp{spans: tableSpans(3)}, 131 protectOp{ 132 spans: tableSpans(4), 133 expErr: "protectedts: limit exceeded: 3\\+1 > 3 spans", 134 }, 135 }, 136 }, 137 { 138 name: "Protect - too many bytes", 139 ops: []op{ 140 protectOp{spans: tableSpans(42)}, 141 funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) { 142 _, err := tCtx.tc.ServerConn(0).Exec("SET CLUSTER SETTING kv.protectedts.max_bytes = $1", 1024) 143 require.NoError(t, err) 144 }), 145 protectOp{ 146 spans: append(tableSpans(1, 2), 147 func() roachpb.Span { 148 s := tableSpan(3) 149 s.EndKey = append(s.EndKey, bytes.Repeat([]byte{'a'}, 1024)...) 150 return s 151 }()), 152 expErr: "protectedts: limit exceeded: 8\\+1050 > 1024 bytes", 153 }, 154 protectOp{ 155 spans: tableSpans(1, 2), 156 }, 157 }, 158 }, 159 { 160 name: "GetRecord - does not exist", 161 ops: []op{ 162 funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) { 163 var rec *ptpb.Record 164 err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 165 rec, err = tCtx.pts.GetRecord(ctx, txn, randomID(tCtx)) 166 return err 167 }) 168 require.EqualError(t, err, protectedts.ErrNotExists.Error()) 169 require.Nil(t, rec) 170 }), 171 }, 172 }, 173 { 174 name: "MarkVerified", 175 ops: []op{ 176 protectOp{spans: tableSpans(42)}, 177 markVerifiedOp{idFunc: pickOneRecord}, 178 markVerifiedOp{idFunc: pickOneRecord}, // it's idempotent 179 markVerifiedOp{ 180 idFunc: randomID, 181 expErr: protectedts.ErrNotExists.Error(), 182 }, 183 }, 184 }, 185 { 186 name: "Release", 187 ops: []op{ 188 protectOp{spans: tableSpans(42)}, 189 releaseOp{idFunc: pickOneRecord}, 190 releaseOp{ 191 idFunc: randomID, 192 expErr: protectedts.ErrNotExists.Error(), 193 }, 194 }, 195 }, 196 { 197 name: "nil transaction errors", 198 ops: []op{ 199 funcOp(func(ctx context.Context, t *testing.T, tCtx *testContext) { 200 rec := newRecord(tCtx.tc.Server(0).Clock().Now(), "", nil, tableSpan(42)) 201 const msg = "must provide a non-nil transaction" 202 require.Regexp(t, msg, tCtx.pts.Protect(ctx, nil /* txn */, &rec).Error()) 203 require.Regexp(t, msg, tCtx.pts.Release(ctx, nil /* txn */, uuid.MakeV4()).Error()) 204 require.Regexp(t, msg, tCtx.pts.MarkVerified(ctx, nil /* txn */, uuid.MakeV4()).Error()) 205 _, err := tCtx.pts.GetRecord(ctx, nil /* txn */, uuid.MakeV4()) 206 require.Regexp(t, msg, err.Error()) 207 _, err = tCtx.pts.GetMetadata(ctx, nil /* txn */) 208 require.Regexp(t, msg, err.Error()) 209 _, err = tCtx.pts.GetState(ctx, nil /* txn */) 210 require.Regexp(t, msg, err.Error()) 211 }), 212 }, 213 }, 214 } 215 216 type testContext struct { 217 pts protectedts.Storage 218 tc *testcluster.TestCluster 219 db *kv.DB 220 221 state ptpb.State 222 } 223 224 type op interface { 225 run(ctx context.Context, t *testing.T, testCtx *testContext) 226 } 227 228 type funcOp func(ctx context.Context, t *testing.T, tCtx *testContext) 229 230 func (f funcOp) run(ctx context.Context, t *testing.T, tCtx *testContext) { 231 f(ctx, t, tCtx) 232 } 233 234 type releaseOp struct { 235 idFunc func(tCtx *testContext) uuid.UUID 236 expErr string 237 } 238 239 func (r releaseOp) run(ctx context.Context, t *testing.T, tCtx *testContext) { 240 id := r.idFunc(tCtx) 241 err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 242 return tCtx.pts.Release(ctx, txn, id) 243 }) 244 if !testutils.IsError(err, r.expErr) { 245 t.Fatalf("expected error to match %q, got %q", r.expErr, err) 246 } 247 if err == nil { 248 i := sort.Search(len(tCtx.state.Records), func(i int) bool { 249 return bytes.Compare(id[:], tCtx.state.Records[i].ID[:]) <= 0 250 }) 251 rec := tCtx.state.Records[i] 252 tCtx.state.Records = append(tCtx.state.Records[:i], tCtx.state.Records[i+1:]...) 253 if len(tCtx.state.Records) == 0 { 254 tCtx.state.Records = nil 255 } 256 tCtx.state.Version++ 257 tCtx.state.NumRecords-- 258 tCtx.state.NumSpans -= uint64(len(rec.Spans)) 259 encoded, err := protoutil.Marshal(&ptstorage.Spans{Spans: rec.Spans}) 260 require.NoError(t, err) 261 tCtx.state.TotalBytes -= uint64(len(encoded) + len(rec.Meta) + len(rec.MetaType)) 262 } 263 } 264 265 type markVerifiedOp struct { 266 idFunc func(tCtx *testContext) uuid.UUID 267 expErr string 268 } 269 270 func (mv markVerifiedOp) run(ctx context.Context, t *testing.T, tCtx *testContext) { 271 id := mv.idFunc(tCtx) 272 err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 273 return tCtx.pts.MarkVerified(ctx, txn, id) 274 }) 275 if !testutils.IsError(err, mv.expErr) { 276 t.Fatalf("expected error to match %q, got %q", mv.expErr, err) 277 } 278 if err == nil { 279 i := sort.Search(len(tCtx.state.Records), func(i int) bool { 280 return bytes.Compare(id[:], tCtx.state.Records[i].ID[:]) <= 0 281 }) 282 tCtx.state.Records[i].Verified = true 283 } 284 } 285 286 type protectOp struct { 287 idFunc func(*testContext) uuid.UUID 288 metaType string 289 meta []byte 290 spans []roachpb.Span 291 expErr string 292 } 293 294 func (p protectOp) run(ctx context.Context, t *testing.T, tCtx *testContext) { 295 rec := newRecord(tCtx.tc.Server(0).Clock().Now(), p.metaType, p.meta, p.spans...) 296 if p.idFunc != nil { 297 rec.ID = p.idFunc(tCtx) 298 } 299 err := tCtx.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 300 return tCtx.pts.Protect(ctx, txn, &rec) 301 }) 302 if !testutils.IsError(err, p.expErr) { 303 t.Fatalf("expected error to match %q, got %q", p.expErr, err) 304 } 305 if err == nil { 306 i := sort.Search(len(tCtx.state.Records), func(i int) bool { 307 return bytes.Compare(rec.ID[:], tCtx.state.Records[i].ID[:]) <= 0 308 }) 309 tail := tCtx.state.Records[i:] 310 tCtx.state.Records = append(tCtx.state.Records[:i:i], rec) 311 tCtx.state.Records = append(tCtx.state.Records, tail...) 312 tCtx.state.Version++ 313 tCtx.state.NumRecords++ 314 tCtx.state.NumSpans += uint64(len(rec.Spans)) 315 encoded, err := protoutil.Marshal(&ptstorage.Spans{Spans: p.spans}) 316 require.NoError(t, err) 317 tCtx.state.TotalBytes += uint64(len(encoded) + len(p.meta) + len(p.metaType)) 318 } 319 } 320 321 type testCase struct { 322 name string 323 ops []op 324 } 325 326 func (test testCase) run(t *testing.T) { 327 ctx := context.Background() 328 tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{}) 329 defer tc.Stopper().Stop(ctx) 330 331 s := tc.Server(0) 332 pts := ptstorage.New(s.ClusterSettings(), 333 s.InternalExecutor().(*sql.InternalExecutor)) 334 db := s.DB() 335 tCtx := testContext{ 336 pts: pts, 337 db: db, 338 tc: tc, 339 } 340 verify := func(t *testing.T) { 341 var state ptpb.State 342 require.NoError(t, db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 343 state, err = pts.GetState(ctx, txn) 344 return err 345 })) 346 var md ptpb.Metadata 347 require.NoError(t, db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 348 md, err = pts.GetMetadata(ctx, txn) 349 return err 350 })) 351 require.EqualValues(t, tCtx.state, state) 352 require.EqualValues(t, tCtx.state.Metadata, md) 353 for _, r := range tCtx.state.Records { 354 var rec *ptpb.Record 355 require.NoError(t, db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 356 rec, err = pts.GetRecord(ctx, txn, r.ID) 357 return err 358 })) 359 require.EqualValues(t, &r, rec) 360 } 361 } 362 363 for i, tOp := range test.ops { 364 if !t.Run(strconv.Itoa(i), func(t *testing.T) { 365 tOp.run(ctx, t, &tCtx) 366 verify(t) 367 }) { 368 break 369 } 370 } 371 } 372 373 func randomID(*testContext) uuid.UUID { 374 return uuid.MakeV4() 375 } 376 377 func pickOneRecord(tCtx *testContext) uuid.UUID { 378 numRecords := len(tCtx.state.Records) 379 if numRecords == 0 { 380 panic(fmt.Errorf("cannot pick one from zero records: %+v", tCtx)) 381 } 382 return tCtx.state.Records[rand.Intn(numRecords)].ID 383 } 384 385 func tableSpan(tableID uint32) roachpb.Span { 386 return roachpb.Span{ 387 Key: keys.SystemSQLCodec.TablePrefix(tableID), 388 EndKey: keys.SystemSQLCodec.TablePrefix(tableID).PrefixEnd(), 389 } 390 } 391 392 func tableSpans(tableIDs ...uint32) []roachpb.Span { 393 spans := make([]roachpb.Span, len(tableIDs)) 394 for i, tableID := range tableIDs { 395 spans[i] = tableSpan(tableID) 396 } 397 return spans 398 } 399 400 func newRecord(ts hlc.Timestamp, metaType string, meta []byte, spans ...roachpb.Span) ptpb.Record { 401 return ptpb.Record{ 402 ID: uuid.MakeV4(), 403 Timestamp: ts, 404 Mode: ptpb.PROTECT_AFTER, 405 MetaType: metaType, 406 Meta: meta, 407 Spans: spans, 408 } 409 } 410 411 // TestCorruptData exercises the handling of malformed data inside the protected 412 // timestamp tables. We don't anticipate this ever happening and it would 413 // generally be a bad thing. Nevertheless, we plan for the worst and need to 414 // understand the system behavior in that scenario. 415 // 416 // The main source of corruption in the subsystem would be malformed encoded 417 // spans. Another possible form of corruption would be that the metadata does 418 // not align with the data. The metadata misalignment will not lead to a 419 // foreground error anywhere. Corrupt spans could. 420 // 421 // A corrupt spans entry only impacts GetRecord and GetState. In both cases 422 // we omit the spans from the entry and return it, logging the error. We prefer 423 // logging the error over returning it as there's a chance that the code is 424 // merely trying to remove the malformed data. The returned Record which 425 // contains no spans will be invalid and cannot be Verified. Such a Record 426 // can be removed. 427 func TestCorruptData(t *testing.T) { 428 ctx := context.Background() 429 430 t.Run("corrupt spans", func(t *testing.T) { 431 tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{}) 432 defer tc.Stopper().Stop(ctx) 433 434 s := tc.Server(0) 435 pts := ptstorage.New(s.ClusterSettings(), 436 s.InternalExecutor().(*sql.InternalExecutor)) 437 438 rec := newRecord(s.Clock().Now(), "foo", []byte("bar"), tableSpan(42)) 439 require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 440 return pts.Protect(ctx, txn, &rec) 441 })) 442 ie := tc.Server(0).InternalExecutor().(sqlutil.InternalExecutor) 443 affected, err := ie.ExecEx( 444 ctx, "corrupt-data", nil, /* txn */ 445 sqlbase.InternalExecutorSessionDataOverride{User: security.NodeUser}, 446 "UPDATE system.protected_ts_records SET spans = $1 WHERE id = $2", 447 []byte("junk"), rec.ID.String()) 448 require.NoError(t, err) 449 require.Equal(t, 1, affected) 450 451 // Set the log scope so we can introspect the logged errors. 452 scope := log.Scope(t) 453 defer scope.Close(t) 454 455 var got *ptpb.Record 456 msg := regexp.MustCompile("failed to unmarshal spans for " + rec.ID.String() + ": ") 457 require.Regexp(t, msg, 458 s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 459 got, err = pts.GetRecord(ctx, txn, rec.ID) 460 return err 461 }).Error()) 462 require.Nil(t, got) 463 require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 464 _, err = pts.GetState(ctx, txn) 465 return err 466 })) 467 log.Flush() 468 entries, err := log.FetchEntriesFromFiles(0, math.MaxInt64, 100, msg) 469 require.NoError(t, err) 470 require.Len(t, entries, 1) 471 for _, e := range entries { 472 require.Equal(t, log.Severity_ERROR, e.Severity) 473 } 474 }) 475 t.Run("corrupt hlc timestamp", func(t *testing.T) { 476 tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{}) 477 defer tc.Stopper().Stop(ctx) 478 479 s := tc.Server(0) 480 pts := ptstorage.New(s.ClusterSettings(), 481 s.InternalExecutor().(*sql.InternalExecutor)) 482 483 rec := newRecord(s.Clock().Now(), "foo", []byte("bar"), tableSpan(42)) 484 require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 485 return pts.Protect(ctx, txn, &rec) 486 })) 487 488 // This timestamp has too many logical digits and thus will fail parsing. 489 var d tree.DDecimal 490 d.SetFinite(math.MaxInt32, -12) 491 ie := tc.Server(0).InternalExecutor().(sqlutil.InternalExecutor) 492 affected, err := ie.ExecEx( 493 ctx, "corrupt-data", nil, /* txn */ 494 sqlbase.InternalExecutorSessionDataOverride{User: security.NodeUser}, 495 "UPDATE system.protected_ts_records SET ts = $1 WHERE id = $2", 496 d.String(), rec.ID.String()) 497 require.NoError(t, err) 498 require.Equal(t, 1, affected) 499 500 // Set the log scope so we can introspect the logged errors. 501 scope := log.Scope(t) 502 defer scope.Close(t) 503 504 var got *ptpb.Record 505 msg := regexp.MustCompile("failed to parse timestamp for " + rec.ID.String() + 506 ": logical part has too many digits") 507 require.Regexp(t, msg, 508 s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 509 got, err = pts.GetRecord(ctx, txn, rec.ID) 510 return err 511 })) 512 require.Nil(t, got) 513 require.NoError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) (err error) { 514 _, err = pts.GetState(ctx, txn) 515 return err 516 })) 517 log.Flush() 518 519 entries, err := log.FetchEntriesFromFiles(0, math.MaxInt64, 100, msg) 520 require.NoError(t, err) 521 require.Len(t, entries, 1) 522 for _, e := range entries { 523 require.Equal(t, log.Severity_ERROR, e.Severity) 524 } 525 }) 526 } 527 528 // TestErrorsFromSQL ensures that errors from the underlying InternalExecutor 529 // are properly transmitted back to the client. 530 func TestErrorsFromSQL(t *testing.T) { 531 ctx := context.Background() 532 tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{}) 533 defer tc.Stopper().Stop(ctx) 534 535 s := tc.Server(0) 536 ie := s.InternalExecutor().(sqlutil.InternalExecutor) 537 wrappedIE := &wrappedInternalExecutor{wrapped: ie} 538 pts := ptstorage.New(s.ClusterSettings(), wrappedIE) 539 540 wrappedIE.setErrFunc(func(string) error { 541 return errors.New("boom") 542 }) 543 rec := newRecord(s.Clock().Now(), "foo", []byte("bar"), tableSpan(42)) 544 require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 545 return pts.Protect(ctx, txn, &rec) 546 }), fmt.Sprintf("failed to write record %v: boom", rec.ID)) 547 require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 548 _, err := pts.GetRecord(ctx, txn, rec.ID) 549 return err 550 }), fmt.Sprintf("failed to read record %v: boom", rec.ID)) 551 require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 552 return pts.MarkVerified(ctx, txn, rec.ID) 553 }), fmt.Sprintf("failed to mark record %v as verified: boom", rec.ID)) 554 require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 555 return pts.Release(ctx, txn, rec.ID) 556 }), fmt.Sprintf("failed to release record %v: boom", rec.ID)) 557 require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 558 _, err := pts.GetMetadata(ctx, txn) 559 return err 560 }), "failed to read metadata: boom") 561 require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 562 _, err := pts.GetState(ctx, txn) 563 return err 564 }), "failed to read metadata: boom") 565 // Test that we get an error retrieving the records in GetState. 566 // The preceding call tested the error while retriving the metadata in a 567 // call to GetState. 568 var seen bool 569 wrappedIE.setErrFunc(func(string) error { 570 if !seen { 571 seen = true 572 return nil 573 } 574 return errors.New("boom") 575 }) 576 require.EqualError(t, s.DB().Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { 577 _, err := pts.GetState(ctx, txn) 578 return err 579 }), "failed to read records: boom") 580 } 581 582 // wrappedInternalExecutor allows errors to be injected in SQL execution. 583 type wrappedInternalExecutor struct { 584 wrapped sqlutil.InternalExecutor 585 586 mu struct { 587 syncutil.RWMutex 588 errFunc func(statement string) error 589 } 590 } 591 592 var _ sqlutil.InternalExecutor = &wrappedInternalExecutor{} 593 594 func (ie *wrappedInternalExecutor) Exec( 595 ctx context.Context, opName string, txn *kv.Txn, statement string, params ...interface{}, 596 ) (int, error) { 597 panic("unimplemented") 598 } 599 600 func (ie *wrappedInternalExecutor) ExecEx( 601 ctx context.Context, 602 opName string, 603 txn *kv.Txn, 604 o sqlbase.InternalExecutorSessionDataOverride, 605 stmt string, 606 qargs ...interface{}, 607 ) (int, error) { 608 panic("unimplemented") 609 } 610 611 func (ie *wrappedInternalExecutor) QueryEx( 612 ctx context.Context, 613 opName string, 614 txn *kv.Txn, 615 session sqlbase.InternalExecutorSessionDataOverride, 616 stmt string, 617 qargs ...interface{}, 618 ) ([]tree.Datums, error) { 619 if f := ie.getErrFunc(); f != nil { 620 if err := f(stmt); err != nil { 621 return nil, err 622 } 623 } 624 return ie.wrapped.QueryEx(ctx, opName, txn, session, stmt, qargs...) 625 } 626 627 func (ie *wrappedInternalExecutor) QueryWithCols( 628 ctx context.Context, 629 opName string, 630 txn *kv.Txn, 631 o sqlbase.InternalExecutorSessionDataOverride, 632 statement string, 633 qargs ...interface{}, 634 ) ([]tree.Datums, sqlbase.ResultColumns, error) { 635 panic("unimplemented") 636 } 637 638 func (ie *wrappedInternalExecutor) QueryRowEx( 639 ctx context.Context, 640 opName string, 641 txn *kv.Txn, 642 session sqlbase.InternalExecutorSessionDataOverride, 643 stmt string, 644 qargs ...interface{}, 645 ) (tree.Datums, error) { 646 if f := ie.getErrFunc(); f != nil { 647 if err := f(stmt); err != nil { 648 return nil, err 649 } 650 } 651 return ie.wrapped.QueryRowEx(ctx, opName, txn, session, stmt, qargs...) 652 } 653 654 func (ie *wrappedInternalExecutor) Query( 655 ctx context.Context, opName string, txn *kv.Txn, statement string, params ...interface{}, 656 ) ([]tree.Datums, error) { 657 panic("not implemented") 658 } 659 660 func (ie *wrappedInternalExecutor) QueryRow( 661 ctx context.Context, opName string, txn *kv.Txn, statement string, qargs ...interface{}, 662 ) (tree.Datums, error) { 663 panic("not implemented") 664 } 665 666 func (ie *wrappedInternalExecutor) getErrFunc() func(statement string) error { 667 ie.mu.RLock() 668 defer ie.mu.RUnlock() 669 return ie.mu.errFunc 670 } 671 672 func (ie *wrappedInternalExecutor) setErrFunc(f func(statement string) error) { 673 ie.mu.Lock() 674 defer ie.mu.Unlock() 675 ie.mu.errFunc = f 676 }