vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletconntest/tabletconntest.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 // Package tabletconntest provides the test methods to make sure a 18 // tabletconn/queryservice pair over RPC works correctly. 19 package tabletconntest 20 21 import ( 22 "context" 23 "io" 24 "os" 25 "strings" 26 "testing" 27 28 "github.com/spf13/pflag" 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 "google.golang.org/protobuf/proto" 32 33 "vitess.io/vitess/go/sqltypes" 34 "vitess.io/vitess/go/vt/callerid" 35 "vitess.io/vitess/go/vt/grpcclient" 36 "vitess.io/vitess/go/vt/log" 37 "vitess.io/vitess/go/vt/servenv" 38 "vitess.io/vitess/go/vt/vterrors" 39 "vitess.io/vitess/go/vt/vttablet/queryservice" 40 "vitess.io/vitess/go/vt/vttablet/tabletconn" 41 42 querypb "vitess.io/vitess/go/vt/proto/query" 43 topodatapb "vitess.io/vitess/go/vt/proto/topodata" 44 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 45 ) 46 47 // testErrorHelper will check one instance of each error type, 48 // to make sure we propagate the errors properly. 49 func testErrorHelper(t *testing.T, f *FakeQueryService, name string, ef func(context.Context) error) { 50 errors := []error{ 51 // A few generic errors 52 vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "generic error"), 53 vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "uncaught panic"), 54 vterrors.Errorf(vtrpcpb.Code_UNAUTHENTICATED, "missing caller id"), 55 vterrors.Errorf(vtrpcpb.Code_PERMISSION_DENIED, "table acl error: nil acl"), 56 57 // Client will retry on this specific error 58 vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "query disallowed due to rule: %v", "cool rule"), 59 60 // Client may retry on another server on this specific error 61 vterrors.Errorf(vtrpcpb.Code_INTERNAL, "could not verify strict mode"), 62 63 // This is usually transaction pool full 64 vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "transaction pool connection limit exceeded"), 65 66 // Transaction expired or was unknown 67 vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction 12"), 68 } 69 for _, e := range errors { 70 f.TabletError = e 71 ctx := context.Background() 72 err := ef(ctx) 73 if err == nil { 74 t.Errorf("error wasn't returned for %v?", name) 75 continue 76 } 77 78 // First we check the recoverable vtrpc code is right. 79 code := vterrors.Code(err) 80 wantcode := vterrors.Code(e) 81 if code != wantcode { 82 t.Errorf("unexpected server code from %v: got %v, wanted %v", name, code, wantcode) 83 } 84 85 if !strings.Contains(err.Error(), e.Error()) { 86 t.Errorf("client error message '%v' for %v doesn't contain expected server text message '%v'", err.Error(), name, e) 87 } 88 } 89 f.TabletError = nil 90 } 91 92 func testPanicHelper(t *testing.T, f *FakeQueryService, name string, pf func(context.Context) error) { 93 f.Panics = true 94 ctx := context.Background() 95 if err := pf(ctx); err == nil || !strings.Contains(err.Error(), "caught test panic") { 96 t.Fatalf("unexpected panic error for %v: %v", name, err) 97 } 98 f.Panics = false 99 } 100 101 func testBegin(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 102 t.Log("testBegin") 103 ctx := context.Background() 104 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 105 state, err := conn.Begin(ctx, TestTarget, TestExecuteOptions) 106 if err != nil { 107 t.Fatalf("Begin failed: %v", err) 108 } 109 if state.TransactionID != beginTransactionID { 110 t.Errorf("Unexpected result from Begin: got %v wanted %v", state.TransactionID, beginTransactionID) 111 } 112 assert.Equal(t, TestAlias, state.TabletAlias, "Unexpected tablet alias from Begin") 113 } 114 115 func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 116 t.Log("testBeginError") 117 f.HasBeginError = true 118 testErrorHelper(t, f, "Begin", func(ctx context.Context) error { 119 _, err := conn.Begin(ctx, TestTarget, nil) 120 return err 121 }) 122 f.HasBeginError = false 123 } 124 125 func testBeginPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 126 t.Log("testBeginPanics") 127 testPanicHelper(t, f, "Begin", func(ctx context.Context) error { 128 _, err := conn.Begin(ctx, TestTarget, nil) 129 return err 130 }) 131 } 132 133 func testCommit(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 134 t.Log("testCommit") 135 ctx := context.Background() 136 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 137 _, err := conn.Commit(ctx, TestTarget, commitTransactionID) 138 if err != nil { 139 t.Fatalf("Commit failed: %v", err) 140 } 141 } 142 143 func testCommitError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 144 t.Log("testCommitError") 145 f.HasError = true 146 testErrorHelper(t, f, "Commit", func(ctx context.Context) error { 147 _, err := conn.Commit(ctx, TestTarget, commitTransactionID) 148 return err 149 }) 150 f.HasError = false 151 } 152 153 func testCommitPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 154 t.Log("testCommitPanics") 155 testPanicHelper(t, f, "Commit", func(ctx context.Context) error { 156 _, err := conn.Commit(ctx, TestTarget, commitTransactionID) 157 return err 158 }) 159 } 160 161 func testRollback(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 162 t.Log("testRollback") 163 ctx := context.Background() 164 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 165 _, err := conn.Rollback(ctx, TestTarget, rollbackTransactionID) 166 if err != nil { 167 t.Fatalf("Rollback failed: %v", err) 168 } 169 } 170 171 func testRollbackError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 172 t.Log("testRollbackError") 173 f.HasError = true 174 testErrorHelper(t, f, "Rollback", func(ctx context.Context) error { 175 _, err := conn.Rollback(ctx, TestTarget, commitTransactionID) 176 return err 177 }) 178 f.HasError = false 179 } 180 181 func testRollbackPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 182 t.Log("testRollbackPanics") 183 testPanicHelper(t, f, "Rollback", func(ctx context.Context) error { 184 _, err := conn.Rollback(ctx, TestTarget, rollbackTransactionID) 185 return err 186 }) 187 } 188 189 func testPrepare(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 190 t.Log("testPrepare") 191 ctx := context.Background() 192 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 193 err := conn.Prepare(ctx, TestTarget, commitTransactionID, Dtid) 194 if err != nil { 195 t.Fatalf("Prepare failed: %v", err) 196 } 197 } 198 199 func testPrepareError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 200 t.Log("testPrepareError") 201 f.HasError = true 202 testErrorHelper(t, f, "Prepare", func(ctx context.Context) error { 203 return conn.Prepare(ctx, TestTarget, commitTransactionID, Dtid) 204 }) 205 f.HasError = false 206 } 207 208 func testPreparePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 209 t.Log("testPreparePanics") 210 testPanicHelper(t, f, "Prepare", func(ctx context.Context) error { 211 return conn.Prepare(ctx, TestTarget, commitTransactionID, Dtid) 212 }) 213 } 214 215 func testCommitPrepared(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 216 t.Log("testCommitPrepared") 217 ctx := context.Background() 218 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 219 err := conn.CommitPrepared(ctx, TestTarget, Dtid) 220 if err != nil { 221 t.Fatalf("CommitPrepared failed: %v", err) 222 } 223 } 224 225 func testCommitPreparedError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 226 t.Log("testCommitPreparedError") 227 f.HasError = true 228 testErrorHelper(t, f, "CommitPrepared", func(ctx context.Context) error { 229 return conn.CommitPrepared(ctx, TestTarget, Dtid) 230 }) 231 f.HasError = false 232 } 233 234 func testCommitPreparedPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 235 t.Log("testCommitPreparedPanics") 236 testPanicHelper(t, f, "CommitPrepared", func(ctx context.Context) error { 237 return conn.CommitPrepared(ctx, TestTarget, Dtid) 238 }) 239 } 240 241 func testRollbackPrepared(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 242 t.Log("testRollbackPrepared") 243 ctx := context.Background() 244 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 245 err := conn.RollbackPrepared(ctx, TestTarget, Dtid, rollbackTransactionID) 246 if err != nil { 247 t.Fatalf("RollbackPrepared failed: %v", err) 248 } 249 } 250 251 func testRollbackPreparedError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 252 t.Log("testRollbackPreparedError") 253 f.HasError = true 254 testErrorHelper(t, f, "RollbackPrepared", func(ctx context.Context) error { 255 return conn.RollbackPrepared(ctx, TestTarget, Dtid, rollbackTransactionID) 256 }) 257 f.HasError = false 258 } 259 260 func testRollbackPreparedPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 261 t.Log("testRollbackPreparedPanics") 262 testPanicHelper(t, f, "RollbackPrepared", func(ctx context.Context) error { 263 return conn.RollbackPrepared(ctx, TestTarget, Dtid, rollbackTransactionID) 264 }) 265 } 266 267 func testCreateTransaction(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 268 t.Log("testCreateTransaction") 269 ctx := context.Background() 270 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 271 err := conn.CreateTransaction(ctx, TestTarget, Dtid, Participants) 272 if err != nil { 273 t.Fatalf("CreateTransaction failed: %v", err) 274 } 275 } 276 277 func testCreateTransactionError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 278 t.Log("testCreateTransactionError") 279 f.HasError = true 280 testErrorHelper(t, f, "CreateTransaction", func(ctx context.Context) error { 281 return conn.CreateTransaction(ctx, TestTarget, Dtid, Participants) 282 }) 283 f.HasError = false 284 } 285 286 func testCreateTransactionPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 287 t.Log("testCreateTransactionPanics") 288 testPanicHelper(t, f, "CreateTransaction", func(ctx context.Context) error { 289 return conn.CreateTransaction(ctx, TestTarget, Dtid, Participants) 290 }) 291 } 292 293 func testStartCommit(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 294 t.Log("testStartCommit") 295 ctx := context.Background() 296 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 297 err := conn.StartCommit(ctx, TestTarget, commitTransactionID, Dtid) 298 if err != nil { 299 t.Fatalf("StartCommit failed: %v", err) 300 } 301 } 302 303 func testStartCommitError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 304 t.Log("testStartCommitError") 305 f.HasError = true 306 testErrorHelper(t, f, "StartCommit", func(ctx context.Context) error { 307 return conn.StartCommit(ctx, TestTarget, commitTransactionID, Dtid) 308 }) 309 f.HasError = false 310 } 311 312 func testStartCommitPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 313 t.Log("testStartCommitPanics") 314 testPanicHelper(t, f, "StartCommit", func(ctx context.Context) error { 315 return conn.StartCommit(ctx, TestTarget, commitTransactionID, Dtid) 316 }) 317 } 318 319 func testSetRollback(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 320 t.Log("testSetRollback") 321 ctx := context.Background() 322 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 323 err := conn.SetRollback(ctx, TestTarget, Dtid, rollbackTransactionID) 324 if err != nil { 325 t.Fatalf("SetRollback failed: %v", err) 326 } 327 } 328 329 func testSetRollbackError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 330 t.Log("testSetRollbackError") 331 f.HasError = true 332 testErrorHelper(t, f, "SetRollback", func(ctx context.Context) error { 333 return conn.SetRollback(ctx, TestTarget, Dtid, rollbackTransactionID) 334 }) 335 f.HasError = false 336 } 337 338 func testSetRollbackPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 339 t.Log("testSetRollbackPanics") 340 testPanicHelper(t, f, "SetRollback", func(ctx context.Context) error { 341 return conn.SetRollback(ctx, TestTarget, Dtid, rollbackTransactionID) 342 }) 343 } 344 345 func testConcludeTransaction(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 346 t.Log("testConcludeTransaction") 347 ctx := context.Background() 348 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 349 err := conn.ConcludeTransaction(ctx, TestTarget, Dtid) 350 if err != nil { 351 t.Fatalf("ConcludeTransaction failed: %v", err) 352 } 353 } 354 355 func testConcludeTransactionError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 356 t.Log("testConcludeTransactionError") 357 f.HasError = true 358 testErrorHelper(t, f, "ConcludeTransaction", func(ctx context.Context) error { 359 return conn.ConcludeTransaction(ctx, TestTarget, Dtid) 360 }) 361 f.HasError = false 362 } 363 364 func testConcludeTransactionPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 365 t.Log("testConcludeTransactionPanics") 366 testPanicHelper(t, f, "ConcludeTransaction", func(ctx context.Context) error { 367 return conn.ConcludeTransaction(ctx, TestTarget, Dtid) 368 }) 369 } 370 371 func testReadTransaction(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 372 t.Log("testReadTransaction") 373 ctx := context.Background() 374 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 375 metadata, err := conn.ReadTransaction(ctx, TestTarget, Dtid) 376 if err != nil { 377 t.Fatalf("ReadTransaction failed: %v", err) 378 } 379 if !proto.Equal(metadata, Metadata) { 380 t.Errorf("Unexpected result from Execute: got %v wanted %v", metadata, Metadata) 381 } 382 } 383 384 func testReadTransactionError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 385 t.Log("testReadTransactionError") 386 f.HasError = true 387 testErrorHelper(t, f, "ReadTransaction", func(ctx context.Context) error { 388 _, err := conn.ReadTransaction(ctx, TestTarget, Dtid) 389 return err 390 }) 391 f.HasError = false 392 } 393 394 func testReadTransactionPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 395 t.Log("testReadTransactionPanics") 396 testPanicHelper(t, f, "ReadTransaction", func(ctx context.Context) error { 397 _, err := conn.ReadTransaction(ctx, TestTarget, Dtid) 398 return err 399 }) 400 } 401 402 func testExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 403 t.Log("testExecute") 404 f.ExpectedTransactionID = ExecuteTransactionID 405 ctx := context.Background() 406 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 407 qr, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions) 408 if err != nil { 409 t.Fatalf("Execute failed: %v", err) 410 } 411 if !qr.Equal(&ExecuteQueryResult) { 412 t.Errorf("Unexpected result from Execute: got %v wanted %v", qr, ExecuteQueryResult) 413 } 414 } 415 416 func testExecuteError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 417 t.Log("testExecuteError") 418 f.HasError = true 419 testErrorHelper(t, f, "Execute", func(ctx context.Context) error { 420 _, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions) 421 return err 422 }) 423 f.HasError = false 424 } 425 426 func testExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 427 t.Log("testExecutePanics") 428 testPanicHelper(t, f, "Execute", func(ctx context.Context) error { 429 _, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions) 430 return err 431 }) 432 } 433 434 func testBeginExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 435 t.Log("testBeginExecute") 436 f.ExpectedTransactionID = beginTransactionID 437 ctx := context.Background() 438 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 439 state, qr, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) 440 if err != nil { 441 t.Fatalf("BeginExecute failed: %v", err) 442 } 443 if state.TransactionID != beginTransactionID { 444 t.Errorf("Unexpected result from BeginExecute: got %v wanted %v", state.TransactionID, beginTransactionID) 445 } 446 if !qr.Equal(&ExecuteQueryResult) { 447 t.Errorf("Unexpected result from BeginExecute: got %v wanted %v", qr, ExecuteQueryResult) 448 } 449 assert.Equal(t, TestAlias, state.TabletAlias, "Unexpected tablet alias from Begin") 450 } 451 452 func testBeginExecuteErrorInBegin(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 453 t.Log("testBeginExecuteErrorInBegin") 454 f.HasBeginError = true 455 testErrorHelper(t, f, "BeginExecute.Begin", func(ctx context.Context) error { 456 state, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) 457 if state.TransactionID != 0 { 458 t.Errorf("Unexpected transactionID from BeginExecute: got %v wanted 0", state.TransactionID) 459 } 460 return err 461 }) 462 f.HasBeginError = false 463 } 464 465 func testBeginExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 466 t.Log("testBeginExecuteErrorInExecute") 467 f.HasError = true 468 testErrorHelper(t, f, "BeginExecute.Execute", func(ctx context.Context) error { 469 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 470 state, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) 471 if state.TransactionID != beginTransactionID { 472 t.Errorf("Unexpected transactionID from BeginExecute: got %v wanted %v", state.TransactionID, beginTransactionID) 473 } 474 return err 475 }) 476 f.HasError = false 477 } 478 479 func testBeginExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 480 t.Log("testBeginExecutePanics") 481 testPanicHelper(t, f, "BeginExecute", func(ctx context.Context) error { 482 _, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) 483 return err 484 }) 485 } 486 487 func testStreamExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 488 t.Log("testStreamExecute") 489 ctx := context.Background() 490 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 491 i := 0 492 err := conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 493 switch i { 494 case 0: 495 if len(qr.Rows) == 0 { 496 qr.Rows = nil 497 } 498 if !qr.Equal(&StreamExecuteQueryResult1) { 499 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 500 } 501 case 1: 502 if len(qr.Fields) == 0 { 503 qr.Fields = nil 504 } 505 if !qr.Equal(&StreamExecuteQueryResult2) { 506 t.Errorf("Unexpected result2 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult2) 507 } 508 default: 509 t.Fatal("callback should not be called any more") 510 } 511 i++ 512 if i >= 2 { 513 return io.EOF 514 } 515 return nil 516 }) 517 if err != nil { 518 t.Fatal(err) 519 } 520 } 521 522 func testStreamExecuteError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 523 t.Log("testStreamExecuteError") 524 f.HasError = true 525 testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error { 526 f.ErrorWait = make(chan struct{}) 527 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 528 return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 529 // For some errors, the call can be retried. 530 select { 531 case <-f.ErrorWait: 532 return nil 533 default: 534 } 535 if len(qr.Rows) == 0 { 536 qr.Rows = nil 537 } 538 if !qr.Equal(&StreamExecuteQueryResult1) { 539 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 540 } 541 // signal to the server that the first result has been received 542 close(f.ErrorWait) 543 return nil 544 }) 545 }) 546 f.HasError = false 547 } 548 549 func testStreamExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 550 t.Log("testStreamExecutePanics") 551 // early panic is before sending the Fields, that is returned 552 // by the StreamExecute call itself, or as the first error 553 // by ErrFunc 554 f.StreamExecutePanicsEarly = true 555 testPanicHelper(t, f, "StreamExecute.Early", func(ctx context.Context) error { 556 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 557 return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 558 return nil 559 }) 560 }) 561 562 // late panic is after sending Fields 563 f.StreamExecutePanicsEarly = false 564 testPanicHelper(t, f, "StreamExecute.Late", func(ctx context.Context) error { 565 f.PanicWait = make(chan struct{}) 566 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 567 return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 568 // For some errors, the call can be retried. 569 select { 570 case <-f.PanicWait: 571 return nil 572 default: 573 } 574 if len(qr.Rows) == 0 { 575 qr.Rows = nil 576 } 577 if !qr.Equal(&StreamExecuteQueryResult1) { 578 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 579 } 580 // signal to the server that the first result has been received 581 close(f.PanicWait) 582 return nil 583 }) 584 }) 585 } 586 587 func testBeginStreamExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 588 t.Log("testBeginStreamExecute") 589 ctx := context.Background() 590 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 591 i := 0 592 _, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 593 switch i { 594 case 0: 595 if len(qr.Rows) == 0 { 596 qr.Rows = nil 597 } 598 if !qr.Equal(&StreamExecuteQueryResult1) { 599 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 600 } 601 case 1: 602 if len(qr.Fields) == 0 { 603 qr.Fields = nil 604 } 605 if !qr.Equal(&StreamExecuteQueryResult2) { 606 t.Errorf("Unexpected result2 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult2) 607 } 608 default: 609 t.Fatal("callback should not be called any more") 610 } 611 i++ 612 if i >= 2 { 613 return io.EOF 614 } 615 return nil 616 }) 617 if err != nil { 618 t.Fatal(err) 619 } 620 } 621 622 func testReserveStreamExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 623 t.Log("testReserveStreamExecute") 624 ctx := context.Background() 625 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 626 i := 0 627 _, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 628 switch i { 629 case 0: 630 if len(qr.Rows) == 0 { 631 qr.Rows = nil 632 } 633 if !qr.Equal(&StreamExecuteQueryResult1) { 634 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 635 } 636 case 1: 637 if len(qr.Fields) == 0 { 638 qr.Fields = nil 639 } 640 if !qr.Equal(&StreamExecuteQueryResult2) { 641 t.Errorf("Unexpected result2 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult2) 642 } 643 default: 644 t.Fatal("callback should not be called any more") 645 } 646 i++ 647 if i >= 2 { 648 return io.EOF 649 } 650 return nil 651 }) 652 if err != nil { 653 t.Fatal(err) 654 } 655 } 656 657 func testBeginStreamExecuteErrorInBegin(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 658 t.Log("testBeginExecuteErrorInBegin") 659 f.HasBeginError = true 660 testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error { 661 f.ErrorWait = make(chan struct{}) 662 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 663 _, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 664 // For some errors, the call can be retried. 665 select { 666 case <-f.ErrorWait: 667 return nil 668 default: 669 } 670 if len(qr.Rows) == 0 { 671 qr.Rows = nil 672 } 673 if !qr.Equal(&StreamExecuteQueryResult1) { 674 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 675 } 676 // signal to the server that the first result has been received 677 close(f.ErrorWait) 678 return nil 679 }) 680 return err 681 }) 682 f.HasBeginError = false 683 } 684 685 func testBeginStreamExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 686 t.Log("testBeginStreamExecuteErrorInExecute") 687 f.HasError = true 688 testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error { 689 f.ErrorWait = make(chan struct{}) 690 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 691 state, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 692 // For some errors, the call can be retried. 693 select { 694 case <-f.ErrorWait: 695 return nil 696 default: 697 } 698 if len(qr.Rows) == 0 { 699 qr.Rows = nil 700 } 701 if !qr.Equal(&StreamExecuteQueryResult1) { 702 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 703 } 704 // signal to the server that the first result has been received 705 close(f.ErrorWait) 706 return nil 707 }) 708 require.NotZero(t, state.TransactionID) 709 return err 710 }) 711 f.HasError = false 712 } 713 714 func testReserveStreamExecuteErrorInReserve(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 715 t.Log("testReserveExecuteErrorInReserve") 716 f.HasReserveError = true 717 testErrorHelper(t, f, "ReserveStreamExecute", func(ctx context.Context) error { 718 f.ErrorWait = make(chan struct{}) 719 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 720 _, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 721 // For some errors, the call can be retried. 722 select { 723 case <-f.ErrorWait: 724 return nil 725 default: 726 } 727 if len(qr.Rows) == 0 { 728 qr.Rows = nil 729 } 730 if !qr.Equal(&StreamExecuteQueryResult1) { 731 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 732 } 733 // signal to the server that the first result has been received 734 close(f.ErrorWait) 735 return nil 736 }) 737 return err 738 }) 739 f.HasReserveError = false 740 } 741 742 func testReserveStreamExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 743 t.Log("testReserveStreamExecuteErrorInExecute") 744 f.HasError = true 745 testErrorHelper(t, f, "ReserveStreamExecute", func(ctx context.Context) error { 746 f.ErrorWait = make(chan struct{}) 747 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 748 state, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 749 // For some errors, the call can be retried. 750 select { 751 case <-f.ErrorWait: 752 return nil 753 default: 754 } 755 if len(qr.Rows) == 0 { 756 qr.Rows = nil 757 } 758 if !qr.Equal(&StreamExecuteQueryResult1) { 759 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 760 } 761 // signal to the server that the first result has been received 762 close(f.ErrorWait) 763 return nil 764 }) 765 require.NotZero(t, state.ReservedID) 766 return err 767 }) 768 f.HasError = false 769 } 770 771 func testBeginStreamExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 772 t.Log("testStreamExecutePanics") 773 // early panic is before sending the Fields, that is returned 774 // by the StreamExecute call itself, or as the first error 775 // by ErrFunc 776 f.StreamExecutePanicsEarly = true 777 testPanicHelper(t, f, "StreamExecute.Early", func(ctx context.Context) error { 778 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 779 return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 780 return nil 781 }) 782 }) 783 784 // late panic is after sending Fields 785 f.StreamExecutePanicsEarly = false 786 testPanicHelper(t, f, "StreamExecute.Late", func(ctx context.Context) error { 787 f.PanicWait = make(chan struct{}) 788 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 789 _, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { 790 // For some errors, the call can be retried. 791 select { 792 case <-f.PanicWait: 793 return nil 794 default: 795 } 796 if len(qr.Rows) == 0 { 797 qr.Rows = nil 798 } 799 if !qr.Equal(&StreamExecuteQueryResult1) { 800 t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1) 801 } 802 // signal to the server that the first result has been received 803 close(f.PanicWait) 804 return nil 805 }) 806 return err 807 }) 808 } 809 810 func testMessageStream(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 811 t.Log("testMessageStream") 812 ctx := context.Background() 813 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 814 var got *sqltypes.Result 815 err := conn.MessageStream(ctx, TestTarget, MessageName, func(qr *sqltypes.Result) error { 816 got = qr 817 return nil 818 }) 819 if err != nil { 820 t.Fatalf("MessageStream failed: %v", err) 821 } 822 if !got.Equal(MessageStreamResult) { 823 t.Errorf("Unexpected result from MessageStream: got %v wanted %v", got, MessageStreamResult) 824 } 825 } 826 827 func testMessageStreamError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 828 t.Log("testMessageStreamError") 829 f.HasError = true 830 testErrorHelper(t, f, "MessageStream", func(ctx context.Context) error { 831 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 832 return conn.MessageStream(ctx, TestTarget, MessageName, func(qr *sqltypes.Result) error { return nil }) 833 }) 834 f.HasError = false 835 } 836 837 func testMessageStreamPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 838 t.Log("testMessageStreamPanics") 839 testPanicHelper(t, f, "MessageStream", func(ctx context.Context) error { 840 err := conn.MessageStream(ctx, TestTarget, MessageName, func(qr *sqltypes.Result) error { return nil }) 841 return err 842 }) 843 } 844 845 func testMessageAck(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 846 t.Log("testMessageAck") 847 ctx := context.Background() 848 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 849 count, err := conn.MessageAck(ctx, TestTarget, MessageName, MessageIDs) 850 if err != nil { 851 t.Fatalf("MessageAck failed: %v", err) 852 } 853 if count != 1 { 854 t.Errorf("Unexpected result from MessageAck: got %v wanted 1", count) 855 } 856 } 857 858 func testMessageAckError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 859 t.Log("testMessageAckError") 860 f.HasError = true 861 testErrorHelper(t, f, "MessageAck", func(ctx context.Context) error { 862 ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) 863 _, err := conn.MessageAck(ctx, TestTarget, MessageName, MessageIDs) 864 return err 865 }) 866 f.HasError = false 867 } 868 869 func testMessageAckPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 870 t.Log("testMessageAckPanics") 871 testPanicHelper(t, f, "MessageAck", func(ctx context.Context) error { 872 _, err := conn.MessageAck(ctx, TestTarget, MessageName, MessageIDs) 873 return err 874 }) 875 } 876 877 // this test is a bit of a hack: we write something on the channel 878 // upon registration, and we also return an error, so the streaming query 879 // ends right there. Otherwise we have no real way to trigger a real 880 // communication error, that ends the streaming. 881 func testStreamHealth(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 882 t.Log("testStreamHealth") 883 ctx := context.Background() 884 885 var health *querypb.StreamHealthResponse 886 err := conn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error { 887 health = shr 888 return io.EOF 889 }) 890 if err != nil { 891 t.Fatalf("StreamHealth failed: %v", err) 892 } 893 if !proto.Equal(health, TestStreamHealthStreamHealthResponse) { 894 t.Errorf("invalid StreamHealthResponse: got %v expected %v", health, TestStreamHealthStreamHealthResponse) 895 } 896 } 897 898 func testStreamHealthError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 899 t.Log("testStreamHealthError") 900 f.HasError = true 901 ctx := context.Background() 902 err := conn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error { 903 t.Fatalf("Unexpected call to callback") 904 return nil 905 }) 906 if err == nil || !strings.Contains(err.Error(), TestStreamHealthErrorMsg) { 907 t.Fatalf("StreamHealth failed with the wrong error: %v", err) 908 } 909 f.HasError = false 910 } 911 912 func testStreamHealthPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { 913 t.Log("testStreamHealthPanics") 914 testPanicHelper(t, f, "StreamHealth", func(ctx context.Context) error { 915 return conn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error { 916 t.Fatalf("Unexpected call to callback") 917 return nil 918 }) 919 }) 920 } 921 922 // TestSuite runs all the tests. 923 // If fake.TestingGateway is set, we only test the calls that can go through 924 // a gateway. 925 func TestSuite(t *testing.T, protocol string, tablet *topodatapb.Tablet, fake *FakeQueryService, clientCreds *os.File) { 926 tests := []func(*testing.T, queryservice.QueryService, *FakeQueryService){ 927 // positive test cases 928 testBegin, 929 testCommit, 930 testRollback, 931 testPrepare, 932 testCommitPrepared, 933 testRollbackPrepared, 934 testCreateTransaction, 935 testStartCommit, 936 testSetRollback, 937 testConcludeTransaction, 938 testReadTransaction, 939 testExecute, 940 testBeginExecute, 941 testStreamExecute, 942 testBeginStreamExecute, 943 testMessageStream, 944 testMessageAck, 945 testReserveStreamExecute, 946 947 // error test cases 948 testBeginError, 949 testCommitError, 950 testRollbackError, 951 testPrepareError, 952 testCommitPreparedError, 953 testRollbackPreparedError, 954 testCreateTransactionError, 955 testStartCommitError, 956 testSetRollbackError, 957 testConcludeTransactionError, 958 testReadTransactionError, 959 testExecuteError, 960 testBeginExecuteErrorInBegin, 961 testBeginExecuteErrorInExecute, 962 testStreamExecuteError, 963 testBeginStreamExecuteErrorInBegin, 964 testBeginStreamExecuteErrorInExecute, 965 testReserveStreamExecuteErrorInReserve, 966 testReserveStreamExecuteErrorInExecute, 967 testMessageStreamError, 968 testMessageAckError, 969 970 // panic test cases 971 testBeginPanics, 972 testCommitPanics, 973 testRollbackPanics, 974 testPreparePanics, 975 testCommitPreparedPanics, 976 testRollbackPreparedPanics, 977 testCreateTransactionPanics, 978 testStartCommitPanics, 979 testSetRollbackPanics, 980 testConcludeTransactionPanics, 981 testReadTransactionPanics, 982 testExecutePanics, 983 testBeginExecutePanics, 984 testStreamExecutePanics, 985 testBeginStreamExecutePanics, 986 testMessageStreamPanics, 987 testMessageAckPanics, 988 } 989 990 if !fake.TestingGateway { 991 tests = append(tests, []func(*testing.T, queryservice.QueryService, *FakeQueryService){ 992 // positive test cases 993 testStreamHealth, 994 995 // error test cases 996 testStreamHealthError, 997 998 // panic test cases 999 testStreamHealthPanics, 1000 }...) 1001 } 1002 1003 // make sure we use the right client 1004 SetProtocol(t.Name(), protocol) 1005 1006 // create a connection 1007 if clientCreds != nil { 1008 fs := pflag.NewFlagSet("", pflag.ContinueOnError) 1009 grpcclient.RegisterFlags(fs) 1010 1011 err := fs.Parse([]string{ 1012 "--grpc_auth_static_client_creds", 1013 clientCreds.Name(), 1014 }) 1015 require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", clientCreds.Name()) 1016 } 1017 1018 conn, err := tabletconn.GetDialer()(tablet, grpcclient.FailFast(false)) 1019 if err != nil { 1020 t.Fatalf("dial failed: %v", err) 1021 } 1022 1023 // run the tests 1024 for _, c := range tests { 1025 c(t, conn, fake) 1026 } 1027 1028 // and we're done 1029 conn.Close(context.Background()) 1030 } 1031 1032 const tabletProtocolFlagName = "tablet_protocol" 1033 1034 // SetProtocol is a helper function to set the tabletconn --tablet_protocol flag 1035 // value for tests. 1036 // 1037 // Note that because this variable is bound to a flag, the effects of this 1038 // function are global, not scoped to the calling test-case. Therefore it should 1039 // not be used in conjunction with t.Parallel. 1040 func SetProtocol(name string, protocol string) { 1041 var tmp []string 1042 tmp, os.Args = os.Args[:], []string{name} 1043 defer func() { os.Args = tmp }() 1044 1045 servenv.OnParseFor(name, func(fs *pflag.FlagSet) { 1046 if fs.Lookup(tabletProtocolFlagName) != nil { 1047 return 1048 } 1049 1050 tabletconn.RegisterFlags(fs) 1051 }) 1052 servenv.ParseFlags(name) 1053 1054 if err := pflag.Set(tabletProtocolFlagName, protocol); err != nil { 1055 msg := "failed to set flag %q to %q: %v" 1056 log.Errorf(msg, tabletProtocolFlagName, protocol, err) 1057 } 1058 }