vitess.io/vitess@v0.16.2/go/vt/vtgate/grpcvtgateconn/suite_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package grpcvtgateconn 18 19 // This is agnostic of grpc and was in a separate package 'vtgateconntest'. 20 // This has been moved here for better readability. If we introduce 21 // protocols other than grpc in the future, this will have to be 22 // moved back to its own package for reusability. 23 24 import ( 25 "errors" 26 "fmt" 27 "io" 28 "strings" 29 "testing" 30 31 "google.golang.org/protobuf/proto" 32 33 "context" 34 35 "github.com/stretchr/testify/require" 36 37 "vitess.io/vitess/go/sqltypes" 38 "vitess.io/vitess/go/tb" 39 "vitess.io/vitess/go/vt/callerid" 40 "vitess.io/vitess/go/vt/vterrors" 41 "vitess.io/vitess/go/vt/vtgate/vtgateconn" 42 "vitess.io/vitess/go/vt/vtgate/vtgateservice" 43 44 binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" 45 querypb "vitess.io/vitess/go/vt/proto/query" 46 topodatapb "vitess.io/vitess/go/vt/proto/topodata" 47 vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" 48 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 49 ) 50 51 // fakeVTGateService has the server side of this fake 52 type fakeVTGateService struct { 53 t *testing.T 54 panics bool 55 hasError bool 56 57 errorWait chan struct{} 58 } 59 60 const ( 61 expectedErrMatch string = "test vtgate error" 62 expectedCode vtrpcpb.Code = vtrpcpb.Code_INVALID_ARGUMENT 63 ) 64 65 var errTestVtGateError = vterrors.New(expectedCode, expectedErrMatch) 66 67 func newContext() context.Context { 68 ctx := context.Background() 69 ctx = callerid.NewContext(ctx, testCallerID, nil) 70 return ctx 71 } 72 73 func (f *fakeVTGateService) checkCallerID(ctx context.Context, name string) { 74 ef := callerid.EffectiveCallerIDFromContext(ctx) 75 if ef == nil { 76 f.t.Errorf("no effective caller id for %v", name) 77 } else { 78 if !proto.Equal(ef, testCallerID) { 79 f.t.Errorf("invalid effective caller id for %v: got %v expected %v", name, ef, testCallerID) 80 } 81 } 82 } 83 84 // queryExecute contains all the fields we use to test Execute 85 type queryExecute struct { 86 SQL string 87 BindVariables map[string]*querypb.BindVariable 88 Session *vtgatepb.Session 89 } 90 91 func (q *queryExecute) equal(q2 *queryExecute) bool { 92 return q.SQL == q2.SQL && 93 sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) && 94 proto.Equal(q.Session, q2.Session) 95 } 96 97 // Execute is part of the VTGateService interface 98 func (f *fakeVTGateService) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) { 99 if f.hasError { 100 return session, nil, errTestVtGateError 101 } 102 if f.panics { 103 panic(fmt.Errorf("test forced panic")) 104 } 105 f.checkCallerID(ctx, "Execute") 106 execCase, ok := execMap[sql] 107 if !ok { 108 return session, nil, fmt.Errorf("no match for: %s", sql) 109 } 110 query := &queryExecute{ 111 SQL: sql, 112 BindVariables: bindVariables, 113 Session: session, 114 } 115 if !query.equal(execCase.execQuery) { 116 f.t.Errorf("Execute:\n%+v, want\n%+v", query, execCase.execQuery) 117 return session, nil, nil 118 } 119 if execCase.outSession != nil { 120 proto.Reset(session) 121 proto.Merge(session, execCase.outSession) 122 } 123 return session, execCase.result, nil 124 } 125 126 // ExecuteBatch is part of the VTGateService interface 127 func (f *fakeVTGateService) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) { 128 if f.hasError { 129 return session, nil, errTestVtGateError 130 } 131 if f.panics { 132 panic(fmt.Errorf("test forced panic")) 133 } 134 f.checkCallerID(ctx, "ExecuteBatch") 135 execCase, ok := execMap[sqlList[0]] 136 if !ok { 137 return session, nil, fmt.Errorf("no match for: %s", sqlList[0]) 138 } 139 query := &queryExecute{ 140 SQL: sqlList[0], 141 BindVariables: bindVariablesList[0], 142 Session: session, 143 } 144 if !query.equal(execCase.execQuery) { 145 f.t.Errorf("Execute: %+v, want %+v", query, execCase.execQuery) 146 return session, nil, nil 147 } 148 if execCase.outSession != nil { 149 proto.Reset(session) 150 proto.Merge(session, execCase.outSession) 151 } 152 return session, []sqltypes.QueryResponse{{ 153 QueryResult: execCase.result, 154 QueryError: nil, 155 }}, nil 156 } 157 158 // StreamExecute is part of the VTGateService interface 159 func (f *fakeVTGateService) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { 160 if f.panics { 161 panic(fmt.Errorf("test forced panic")) 162 } 163 execCase, ok := execMap[sql] 164 if !ok { 165 return fmt.Errorf("no match for: %s", sql) 166 } 167 f.checkCallerID(ctx, "StreamExecute") 168 query := &queryExecute{ 169 SQL: sql, 170 BindVariables: bindVariables, 171 Session: session, 172 } 173 if !query.equal(execCase.execQuery) { 174 f.t.Errorf("StreamExecute:\n%+v, want\n%+v", query, execCase.execQuery) 175 return nil 176 } 177 if execCase.result != nil { 178 result := &sqltypes.Result{ 179 Fields: execCase.result.Fields, 180 } 181 if err := callback(result); err != nil { 182 return err 183 } 184 if f.hasError { 185 // wait until the client has the response, since all streaming implementation may not 186 // send previous messages if an error has been triggered. 187 <-f.errorWait 188 f.errorWait = make(chan struct{}) // for next test 189 return errTestVtGateError 190 } 191 for _, row := range execCase.result.Rows { 192 result := &sqltypes.Result{ 193 Rows: [][]sqltypes.Value{row}, 194 } 195 if err := callback(result); err != nil { 196 return err 197 } 198 } 199 } 200 return nil 201 } 202 203 // Prepare is part of the VTGateService interface 204 func (f *fakeVTGateService) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) { 205 if f.hasError { 206 return session, nil, errTestVtGateError 207 } 208 if f.panics { 209 panic(fmt.Errorf("test forced panic")) 210 } 211 f.checkCallerID(ctx, "Prepare") 212 execCase, ok := execMap[sql] 213 if !ok { 214 return session, nil, fmt.Errorf("no match for: %s", sql) 215 } 216 query := &queryExecute{ 217 SQL: sql, 218 BindVariables: bindVariables, 219 Session: session, 220 } 221 if !query.equal(execCase.execQuery) { 222 f.t.Errorf("Prepare:\n%+v, want\n%+v", query, execCase.execQuery) 223 return session, nil, nil 224 } 225 if execCase.outSession != nil { 226 proto.Reset(session) 227 proto.Merge(session, execCase.outSession) 228 } 229 return session, execCase.result.Fields, nil 230 } 231 232 // CloseSession is part of the VTGateService interface 233 func (f *fakeVTGateService) CloseSession(ctx context.Context, session *vtgatepb.Session) error { 234 panic("unimplemented") 235 } 236 237 // ResolveTransaction is part of the VTGateService interface 238 func (f *fakeVTGateService) ResolveTransaction(ctx context.Context, dtid string) error { 239 if f.hasError { 240 return errTestVtGateError 241 } 242 if f.panics { 243 panic(fmt.Errorf("test forced panic")) 244 } 245 f.checkCallerID(ctx, "ResolveTransaction") 246 if dtid != dtid2 { 247 return errors.New("ResolveTransaction: dtid mismatch") 248 } 249 return nil 250 } 251 252 func (f *fakeVTGateService) VStream(ctx context.Context, tabletType topodatapb.TabletType, vgtid *binlogdatapb.VGtid, filter *binlogdatapb.Filter, flags *vtgatepb.VStreamFlags, send func([]*binlogdatapb.VEvent) error) error { 253 panic("unimplemented") 254 } 255 256 // CreateFakeServer returns the fake server for the tests 257 func CreateFakeServer(t *testing.T) vtgateservice.VTGateService { 258 return &fakeVTGateService{ 259 t: t, 260 panics: false, 261 errorWait: make(chan struct{}), 262 } 263 } 264 265 // RegisterTestDialProtocol registers a vtgateconn implementation under the "test" protocol 266 func RegisterTestDialProtocol(impl vtgateconn.Impl) { 267 vtgateconn.RegisterDialer("test", func(ctx context.Context, address string) (vtgateconn.Impl, error) { 268 return impl, nil 269 }) 270 } 271 272 // HandlePanic is part of the VTGateService interface 273 func (f *fakeVTGateService) HandlePanic(err *error) { 274 if x := recover(); x != nil { 275 // gRPC 0.13 chokes when you return a streaming error that contains newlines. 276 *err = fmt.Errorf("uncaught panic: %v, %s", x, 277 strings.Replace(string(tb.Stack(4)), "\n", ";", -1)) 278 } 279 } 280 281 // RunTests runs all the tests 282 func RunTests(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGateService) { 283 vtgateconn.RegisterDialer("test", func(ctx context.Context, address string) (vtgateconn.Impl, error) { 284 return impl, nil 285 }) 286 conn, err := vtgateconn.DialProtocol(context.Background(), "test", "") 287 if err != nil { 288 t.Fatalf("Got err: %v from vtgateconn.DialProtocol", err) 289 } 290 session := conn.Session("connection_ks@rdonly", testExecuteOptions) 291 292 fs := fakeServer.(*fakeVTGateService) 293 294 testExecute(t, session) 295 testStreamExecute(t, session) 296 testExecuteBatch(t, session) 297 testPrepare(t, session) 298 299 // force a panic at every call, then test that works 300 fs.panics = true 301 testExecutePanic(t, session) 302 testExecuteBatchPanic(t, session) 303 testStreamExecutePanic(t, session) 304 testPreparePanic(t, session) 305 fs.panics = false 306 } 307 308 // RunErrorTests runs all the tests that expect errors 309 func RunErrorTests(t *testing.T, fakeServer vtgateservice.VTGateService) { 310 conn, err := vtgateconn.DialProtocol(context.Background(), "test", "") 311 if err != nil { 312 t.Fatalf("Got err: %v from vtgateconn.DialProtocol", err) 313 } 314 session := conn.Session("connection_ks@rdonly", testExecuteOptions) 315 316 fs := fakeServer.(*fakeVTGateService) 317 318 // return an error for every call, make sure they're handled properly 319 fs.hasError = true 320 testExecuteError(t, session, fs) 321 testExecuteBatchError(t, session, fs) 322 testStreamExecuteError(t, session, fs) 323 testPrepareError(t, session, fs) 324 fs.hasError = false 325 } 326 327 func expectPanic(t *testing.T, err error) { 328 expected1 := "test forced panic" 329 expected2 := "uncaught panic" 330 if err == nil || !strings.Contains(err.Error(), expected1) || !strings.Contains(err.Error(), expected2) { 331 t.Fatalf("Expected a panic error with '%v' or '%v' but got: %v", expected1, expected2, err) 332 } 333 } 334 335 // Verifies the returned error has the properties that we expect. 336 func verifyError(t *testing.T, err error, method string) { 337 if err == nil { 338 t.Errorf("%s was expecting an error, didn't get one", method) 339 return 340 } 341 // verify error code 342 code := vterrors.Code(err) 343 if code != expectedCode { 344 t.Errorf("Unexpected error code from %s: got %v, wanted %v", method, code, expectedCode) 345 } 346 verifyErrorString(t, err, method) 347 } 348 349 func verifyErrorString(t *testing.T, err error, method string) { 350 if err == nil { 351 t.Errorf("%s was expecting an error, didn't get one", method) 352 return 353 } 354 355 if !strings.Contains(err.Error(), expectedErrMatch) { 356 t.Errorf("Unexpected error from %s: got %v, wanted err containing: %v", method, err, errTestVtGateError.Error()) 357 } 358 } 359 360 func testExecute(t *testing.T, session *vtgateconn.VTGateSession) { 361 ctx := newContext() 362 execCase := execMap["request1"] 363 qr, err := session.Execute(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 364 require.NoError(t, err) 365 if !qr.Equal(execCase.result) { 366 t.Errorf("Unexpected result from Execute: got\n%#v want\n%#v", qr, execCase.result) 367 } 368 369 _, err = session.Execute(ctx, "none", nil) 370 want := "no match for: none" 371 if err == nil || !strings.Contains(err.Error(), want) { 372 t.Errorf("none request: %v, want %v", err, want) 373 } 374 } 375 376 func testExecuteError(t *testing.T, session *vtgateconn.VTGateSession, fake *fakeVTGateService) { 377 ctx := newContext() 378 execCase := execMap["errorRequst"] 379 380 _, err := session.Execute(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 381 verifyError(t, err, "Execute") 382 } 383 384 func testExecutePanic(t *testing.T, session *vtgateconn.VTGateSession) { 385 ctx := newContext() 386 execCase := execMap["request1"] 387 _, err := session.Execute(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 388 expectPanic(t, err) 389 } 390 391 func testExecuteBatch(t *testing.T, session *vtgateconn.VTGateSession) { 392 ctx := newContext() 393 execCase := execMap["request1"] 394 qr, err := session.ExecuteBatch(ctx, []string{execCase.execQuery.SQL}, []map[string]*querypb.BindVariable{execCase.execQuery.BindVariables}) 395 require.NoError(t, err) 396 if !qr[0].QueryResult.Equal(execCase.result) { 397 t.Errorf("Unexpected result from Execute: got\n%#v want\n%#v", qr, execCase.result) 398 } 399 400 _, err = session.ExecuteBatch(ctx, []string{"none"}, nil) 401 want := "no match for: none" 402 if err == nil || !strings.Contains(err.Error(), want) { 403 t.Errorf("none request: %v, want %v", err, want) 404 } 405 } 406 407 func testExecuteBatchError(t *testing.T, session *vtgateconn.VTGateSession, fake *fakeVTGateService) { 408 ctx := newContext() 409 execCase := execMap["errorRequst"] 410 411 _, err := session.ExecuteBatch(ctx, []string{execCase.execQuery.SQL}, []map[string]*querypb.BindVariable{execCase.execQuery.BindVariables}) 412 verifyError(t, err, "ExecuteBatch") 413 } 414 415 func testExecuteBatchPanic(t *testing.T, session *vtgateconn.VTGateSession) { 416 ctx := newContext() 417 execCase := execMap["request1"] 418 _, err := session.ExecuteBatch(ctx, []string{execCase.execQuery.SQL}, []map[string]*querypb.BindVariable{execCase.execQuery.BindVariables}) 419 expectPanic(t, err) 420 } 421 422 func testStreamExecute(t *testing.T, session *vtgateconn.VTGateSession) { 423 ctx := newContext() 424 execCase := execMap["request1"] 425 stream, err := session.StreamExecute(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 426 if err != nil { 427 t.Fatal(err) 428 } 429 var qr sqltypes.Result 430 for { 431 packet, err := stream.Recv() 432 if err != nil { 433 if err != io.EOF { 434 t.Error(err) 435 } 436 break 437 } 438 if len(packet.Fields) != 0 { 439 qr.Fields = packet.Fields 440 } 441 if len(packet.Rows) != 0 { 442 qr.Rows = append(qr.Rows, packet.Rows...) 443 } 444 } 445 wantResult := *execCase.result 446 wantResult.RowsAffected = 0 447 wantResult.InsertID = 0 448 if !qr.Equal(&wantResult) { 449 t.Errorf("Unexpected result from StreamExecute: got %+v want %+v", qr, wantResult) 450 } 451 452 stream, err = session.StreamExecute(ctx, "none", nil) 453 if err != nil { 454 t.Fatal(err) 455 } 456 _, err = stream.Recv() 457 want := "no match for: none" 458 if err == nil || !strings.Contains(err.Error(), want) { 459 t.Errorf("none request: %v, want %v", err, want) 460 } 461 } 462 463 func testStreamExecuteError(t *testing.T, session *vtgateconn.VTGateSession, fake *fakeVTGateService) { 464 ctx := newContext() 465 execCase := execMap["request1"] 466 stream, err := session.StreamExecute(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 467 if err != nil { 468 t.Fatalf("StreamExecute failed: %v", err) 469 } 470 qr, err := stream.Recv() 471 if err != nil { 472 t.Fatalf("StreamExecute failed: cannot read result1: %v", err) 473 } 474 475 if !qr.Equal(&streamResultFields) { 476 t.Errorf("Unexpected result from StreamExecute: got %#v want %#v", qr, &streamResultFields) 477 } 478 // signal to the server that the first result has been received 479 close(fake.errorWait) 480 // After 1 result, we expect to get an error (no more results). 481 _, err = stream.Recv() 482 if err == nil { 483 t.Fatalf("StreamExecute channel wasn't closed") 484 } 485 verifyError(t, err, "StreamExecute") 486 } 487 488 func testStreamExecutePanic(t *testing.T, session *vtgateconn.VTGateSession) { 489 ctx := newContext() 490 execCase := execMap["request1"] 491 stream, err := session.StreamExecute(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 492 if err != nil { 493 t.Fatal(err) 494 } 495 _, err = stream.Recv() 496 if err == nil { 497 t.Fatalf("Received packets instead of panic?") 498 } 499 expectPanic(t, err) 500 } 501 502 func testPrepare(t *testing.T, session *vtgateconn.VTGateSession) { 503 ctx := newContext() 504 execCase := execMap["request1"] 505 _, err := session.Prepare(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 506 require.NoError(t, err) 507 //if !qr.Equal(execCase.result) { 508 // t.Errorf("Unexpected result from Execute: got\n%#v want\n%#v", qr, execCase.result) 509 //} 510 511 _, err = session.Prepare(ctx, "none", nil) 512 require.EqualError(t, err, "no match for: none") 513 } 514 515 func testPrepareError(t *testing.T, session *vtgateconn.VTGateSession, fake *fakeVTGateService) { 516 ctx := newContext() 517 execCase := execMap["errorRequst"] 518 519 _, err := session.Prepare(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 520 verifyError(t, err, "Prepare") 521 } 522 523 func testPreparePanic(t *testing.T, session *vtgateconn.VTGateSession) { 524 ctx := newContext() 525 execCase := execMap["request1"] 526 _, err := session.Prepare(ctx, execCase.execQuery.SQL, execCase.execQuery.BindVariables) 527 expectPanic(t, err) 528 } 529 530 var testCallerID = &vtrpcpb.CallerID{ 531 Principal: "test_principal", 532 Component: "test_component", 533 Subcomponent: "test_subcomponent", 534 } 535 536 var testExecuteOptions = &querypb.ExecuteOptions{ 537 IncludedFields: querypb.ExecuteOptions_TYPE_ONLY, 538 } 539 540 var execMap = map[string]struct { 541 execQuery *queryExecute 542 result *sqltypes.Result 543 outSession *vtgatepb.Session 544 err error 545 }{ 546 "request1": { 547 execQuery: &queryExecute{ 548 SQL: "request1", 549 BindVariables: map[string]*querypb.BindVariable{ 550 "bind1": sqltypes.Int64BindVariable(0), 551 }, 552 Session: &vtgatepb.Session{ 553 TargetString: "connection_ks@rdonly", 554 Options: testExecuteOptions, 555 Autocommit: true, 556 }, 557 }, 558 result: &result1, 559 }, 560 "errorRequst": { 561 execQuery: &queryExecute{ 562 SQL: "errorRequst", 563 BindVariables: map[string]*querypb.BindVariable{ 564 "bind1": sqltypes.Int64BindVariable(0), 565 }, 566 Session: &vtgatepb.Session{ 567 TargetString: "connection_ks@rdonly", 568 Options: testExecuteOptions, 569 }, 570 }, 571 }, 572 } 573 574 var result1 = sqltypes.Result{ 575 Fields: []*querypb.Field{ 576 { 577 Name: "field1", 578 Type: sqltypes.Int16, 579 }, 580 { 581 Name: "field2", 582 Type: sqltypes.Int32, 583 }, 584 }, 585 RowsAffected: 123, 586 InsertID: 72, 587 Rows: [][]sqltypes.Value{ 588 { 589 sqltypes.TestValue(sqltypes.Int16, "1"), 590 sqltypes.NULL, 591 }, 592 { 593 sqltypes.TestValue(sqltypes.Int16, "2"), 594 sqltypes.NewInt32(3), 595 }, 596 }, 597 } 598 599 // streamResultFields is only the fields, sent as the first packet 600 var streamResultFields = sqltypes.Result{ 601 Fields: result1.Fields, 602 Rows: [][]sqltypes.Value{}, 603 } 604 605 var dtid2 = "aa"