vitess.io/vitess@v0.16.2/go/vt/vitessdriver/driver_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vitessdriver 18 19 import ( 20 "context" 21 "database/sql" 22 "database/sql/driver" 23 "fmt" 24 "net" 25 "os" 26 "reflect" 27 "strings" 28 "testing" 29 "time" 30 31 "github.com/stretchr/testify/assert" 32 33 "github.com/stretchr/testify/require" 34 "google.golang.org/grpc" 35 36 "vitess.io/vitess/go/sqltypes" 37 querypb "vitess.io/vitess/go/vt/proto/query" 38 "vitess.io/vitess/go/vt/vtgate/grpcvtgateservice" 39 ) 40 41 var ( 42 testAddress string 43 ) 44 45 // TestMain tests the Vitess Go SQL driver. 46 // 47 // Note that the queries used in the test are not valid SQL queries and don't 48 // have to be. The main point here is to test the interactions against a 49 // vtgate implementation (here: fakeVTGateService from fakeserver_test.go). 50 func TestMain(m *testing.M) { 51 service := CreateFakeServer() 52 53 // listen on a random port. 54 listener, err := net.Listen("tcp", "127.0.0.1:0") 55 if err != nil { 56 panic(fmt.Sprintf("Cannot listen: %v", err)) 57 } 58 59 // Create a gRPC server and listen on the port. 60 server := grpc.NewServer() 61 grpcvtgateservice.RegisterForTest(server, service) 62 go server.Serve(listener) 63 64 testAddress = listener.Addr().String() 65 os.Exit(m.Run()) 66 } 67 68 func TestOpen(t *testing.T) { 69 locationPST, err := time.LoadLocation("America/Los_Angeles") 70 if err != nil { 71 panic(err) 72 } 73 74 var testcases = []struct { 75 desc string 76 connStr string 77 conn *conn 78 }{ 79 { 80 desc: "Open()", 81 connStr: fmt.Sprintf(`{"address": "%s", "target": "@replica", "timeout": %d}`, testAddress, int64(30*time.Second)), 82 conn: &conn{ 83 Configuration: Configuration{ 84 Protocol: "grpc", 85 DriverName: "vitess", 86 Target: "@replica", 87 }, 88 convert: &converter{ 89 location: time.UTC, 90 }, 91 }, 92 }, 93 { 94 desc: "Open() (defaults omitted)", 95 connStr: fmt.Sprintf(`{"address": "%s", "timeout": %d}`, testAddress, int64(30*time.Second)), 96 conn: &conn{ 97 Configuration: Configuration{ 98 Protocol: "grpc", 99 DriverName: "vitess", 100 }, 101 convert: &converter{ 102 location: time.UTC, 103 }, 104 }, 105 }, 106 { 107 desc: "Open() with keyspace", 108 connStr: fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "target": "ks:0@replica", "timeout": %d}`, testAddress, int64(30*time.Second)), 109 conn: &conn{ 110 Configuration: Configuration{ 111 Protocol: "grpc", 112 DriverName: "vitess", 113 Target: "ks:0@replica", 114 }, 115 convert: &converter{ 116 location: time.UTC, 117 }, 118 }, 119 }, 120 { 121 desc: "Open() with custom timezone", 122 connStr: fmt.Sprintf( 123 `{"address": "%s", "timeout": %d, "defaultlocation": "America/Los_Angeles"}`, 124 testAddress, int64(30*time.Second)), 125 conn: &conn{ 126 Configuration: Configuration{ 127 Protocol: "grpc", 128 DriverName: "vitess", 129 DefaultLocation: "America/Los_Angeles", 130 }, 131 convert: &converter{ 132 location: locationPST, 133 }, 134 }, 135 }, 136 } 137 138 for _, tc := range testcases { 139 c, err := drv{}.Open(tc.connStr) 140 if err != nil { 141 t.Fatal(err) 142 } 143 defer c.Close() 144 145 wantc := tc.conn 146 newc := *(c.(*conn)) 147 newc.Address = "" 148 newc.conn = nil 149 newc.session = nil 150 if !reflect.DeepEqual(&newc, wantc) { 151 t.Errorf("%v: conn:\n%+v, want\n%+v", tc.desc, &newc, wantc) 152 } 153 } 154 } 155 156 func TestOpen_UnregisteredProtocol(t *testing.T) { 157 _, err := drv{}.Open(`{"protocol": "none"}`) 158 want := "no dialer registered for VTGate protocol none" 159 if err == nil || !strings.Contains(err.Error(), want) { 160 t.Errorf("err: %v, want %s", err, want) 161 } 162 } 163 164 func TestOpen_InvalidJson(t *testing.T) { 165 _, err := drv{}.Open(`{`) 166 want := "unexpected end of JSON input" 167 if err == nil || !strings.Contains(err.Error(), want) { 168 t.Errorf("err: %v, want %s", err, want) 169 } 170 } 171 172 func TestBeginIsolation(t *testing.T) { 173 db, err := Open(testAddress, "@primary") 174 require.NoError(t, err) 175 defer db.Close() 176 _, err = db.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) 177 want := errIsolationUnsupported.Error() 178 if err == nil || err.Error() != want { 179 t.Errorf("Begin: %v, want %s", err, want) 180 } 181 } 182 183 func TestExec(t *testing.T) { 184 db, err := Open(testAddress, "@rdonly") 185 if err != nil { 186 t.Fatal(err) 187 } 188 defer db.Close() 189 190 s, err := db.Prepare("request") 191 if err != nil { 192 t.Fatal(err) 193 } 194 defer s.Close() 195 196 r, err := s.Exec(int64(0)) 197 if err != nil { 198 t.Fatal(err) 199 } 200 if v, _ := r.LastInsertId(); v != 72 { 201 t.Errorf("insert id: %d, want 72", v) 202 } 203 if v, _ := r.RowsAffected(); v != 123 { 204 t.Errorf("rows affected: %d, want 123", v) 205 } 206 207 s2, err := db.Prepare("none") 208 if err != nil { 209 t.Fatal(err) 210 } 211 defer s2.Close() 212 213 _, err = s2.Exec() 214 want := "no match for: none" 215 if err == nil || !strings.Contains(err.Error(), want) { 216 t.Errorf("err: %v, does not contain %s", err, want) 217 } 218 } 219 220 func TestConfigurationToJSON(t *testing.T) { 221 config := Configuration{ 222 Protocol: "some-invalid-protocol", 223 Target: "ks2", 224 Streaming: true, 225 DefaultLocation: "Local", 226 } 227 want := `{"Protocol":"some-invalid-protocol","Address":"","Target":"ks2","Streaming":true,"DefaultLocation":"Local","SessionToken":""}` 228 229 json, err := config.toJSON() 230 if err != nil { 231 t.Fatal(err) 232 } 233 if json != want { 234 t.Errorf("Configuration.JSON(): got: %v want: %v", json, want) 235 } 236 } 237 238 func TestExecStreamingNotAllowed(t *testing.T) { 239 db, err := OpenForStreaming(testAddress, "@rdonly") 240 if err != nil { 241 t.Fatal(err) 242 } 243 244 s, err := db.Prepare("request") 245 if err != nil { 246 t.Fatal(err) 247 } 248 defer s.Close() 249 250 _, err = s.Exec(int64(0)) 251 want := "Exec not allowed for streaming connections" 252 if err == nil || !strings.Contains(err.Error(), want) { 253 t.Errorf("err: %v, does not contain %s", err, want) 254 } 255 } 256 257 func TestQuery(t *testing.T) { 258 var testcases = []struct { 259 desc string 260 config Configuration 261 requestName string 262 }{ 263 { 264 desc: "non-streaming, vtgate", 265 config: Configuration{ 266 Protocol: "grpc", 267 Address: testAddress, 268 Target: "@rdonly", 269 }, 270 requestName: "request", 271 }, 272 { 273 desc: "streaming, vtgate", 274 config: Configuration{ 275 Protocol: "grpc", 276 Address: testAddress, 277 Target: "@rdonly", 278 Streaming: true, 279 }, 280 requestName: "request", 281 }, 282 } 283 284 for _, tc := range testcases { 285 db, err := OpenWithConfiguration(tc.config) 286 if err != nil { 287 t.Errorf("%v: %v", tc.desc, err) 288 } 289 defer db.Close() 290 291 s, err := db.Prepare(tc.requestName) 292 if err != nil { 293 t.Errorf("%v: %v", tc.desc, err) 294 } 295 defer s.Close() 296 297 r, err := s.Query(int64(0)) 298 if err != nil { 299 t.Errorf("%v: %v", tc.desc, err) 300 } 301 defer r.Close() 302 cols, err := r.Columns() 303 if err != nil { 304 t.Errorf("%v: %v", tc.desc, err) 305 } 306 wantCols := []string{ 307 "field1", 308 "field2", 309 } 310 if !reflect.DeepEqual(cols, wantCols) { 311 t.Errorf("%v: cols: %v, want %v", tc.desc, cols, wantCols) 312 } 313 count := 0 314 wantValues := []struct { 315 field1 int16 316 field2 string 317 }{{1, "value1"}, {2, "value2"}} 318 for r.Next() { 319 var field1 int16 320 var field2 string 321 err := r.Scan(&field1, &field2) 322 if err != nil { 323 t.Errorf("%v: %v", tc.desc, err) 324 } 325 if want := wantValues[count].field1; field1 != want { 326 t.Errorf("%v: wrong value for field1: got: %v want: %v", tc.desc, field1, want) 327 } 328 if want := wantValues[count].field2; field2 != want { 329 t.Errorf("%v: wrong value for field2: got: %v want: %v", tc.desc, field2, want) 330 } 331 count++ 332 } 333 if count != len(wantValues) { 334 t.Errorf("%v: count: %d, want %d", tc.desc, count, len(wantValues)) 335 } 336 337 s2, err := db.Prepare("none") 338 if err != nil { 339 t.Errorf("%v: %v", tc.desc, err) 340 } 341 defer s2.Close() 342 343 rows, err := s2.Query() 344 want := "no match for: none" 345 if tc.config.Streaming && err == nil { 346 defer rows.Close() 347 // gRPC requires to consume the stream first before the error becomes visible. 348 if rows.Next() { 349 t.Errorf("%v: query should not have returned anything but did.", tc.desc) 350 } 351 err = rows.Err() 352 } 353 if err == nil || !strings.Contains(err.Error(), want) { 354 t.Errorf("%v: err: %v, does not contain %s", tc.desc, err, want) 355 } 356 } 357 } 358 359 func TestBindVars(t *testing.T) { 360 var testcases = []struct { 361 desc string 362 in []driver.NamedValue 363 out map[string]*querypb.BindVariable 364 outErr string 365 }{{ 366 desc: "all names", 367 in: []driver.NamedValue{{ 368 Name: "n1", 369 Value: int64(0), 370 }, { 371 Name: "n2", 372 Value: "abcd", 373 }}, 374 out: map[string]*querypb.BindVariable{ 375 "n1": sqltypes.Int64BindVariable(0), 376 "n2": sqltypes.StringBindVariable("abcd"), 377 }, 378 }, { 379 desc: "prefixed names", 380 in: []driver.NamedValue{{ 381 Name: ":n1", 382 Value: int64(0), 383 }, { 384 Name: "@n2", 385 Value: "abcd", 386 }}, 387 out: map[string]*querypb.BindVariable{ 388 "n1": sqltypes.Int64BindVariable(0), 389 "n2": sqltypes.StringBindVariable("abcd"), 390 }, 391 }, { 392 desc: "all positional", 393 in: []driver.NamedValue{{ 394 Ordinal: 1, 395 Value: int64(0), 396 }, { 397 Ordinal: 2, 398 Value: "abcd", 399 }}, 400 out: map[string]*querypb.BindVariable{ 401 "v1": sqltypes.Int64BindVariable(0), 402 "v2": sqltypes.StringBindVariable("abcd"), 403 }, 404 }, { 405 desc: "name, then position", 406 in: []driver.NamedValue{{ 407 Name: "n1", 408 Value: int64(0), 409 }, { 410 Ordinal: 2, 411 Value: "abcd", 412 }}, 413 outErr: errNoIntermixing.Error(), 414 }, { 415 desc: "position, then name", 416 in: []driver.NamedValue{{ 417 Ordinal: 1, 418 Value: int64(0), 419 }, { 420 Name: "n2", 421 Value: "abcd", 422 }}, 423 outErr: errNoIntermixing.Error(), 424 }} 425 426 converter := &converter{} 427 428 for _, tc := range testcases { 429 t.Run(tc.desc, func(t *testing.T) { 430 bv, err := converter.bindVarsFromNamedValues(tc.in) 431 if tc.outErr != "" { 432 assert.EqualError(t, err, tc.outErr) 433 } else { 434 if !reflect.DeepEqual(bv, tc.out) { 435 t.Errorf("%s: %v, want %v", tc.desc, bv, tc.out) 436 } 437 } 438 }) 439 } 440 } 441 442 func TestDatetimeQuery(t *testing.T) { 443 var testcases = []struct { 444 desc string 445 config Configuration 446 requestName string 447 }{ 448 { 449 desc: "datetime & date, vtgate", 450 config: Configuration{ 451 Protocol: "grpc", 452 Address: testAddress, 453 Target: "@rdonly", 454 }, 455 requestName: "requestDates", 456 }, 457 { 458 desc: "datetime & date (local timezone), vtgate", 459 config: Configuration{ 460 Protocol: "grpc", 461 Address: testAddress, 462 Target: "@rdonly", 463 DefaultLocation: "Local", 464 }, 465 requestName: "requestDates", 466 }, 467 { 468 desc: "datetime & date, streaming, vtgate", 469 config: Configuration{ 470 Protocol: "grpc", 471 Address: testAddress, 472 Target: "@rdonly", 473 Streaming: true, 474 }, 475 requestName: "requestDates", 476 }, 477 } 478 479 for _, tc := range testcases { 480 db, err := OpenWithConfiguration(tc.config) 481 if err != nil { 482 t.Errorf("%v: %v", tc.desc, err) 483 } 484 defer db.Close() 485 486 s, err := db.Prepare(tc.requestName) 487 if err != nil { 488 t.Errorf("%v: %v", tc.desc, err) 489 } 490 defer s.Close() 491 492 r, err := s.Query(0) 493 if err != nil { 494 t.Errorf("%v: %v", tc.desc, err) 495 } 496 defer r.Close() 497 498 cols, err := r.Columns() 499 if err != nil { 500 t.Errorf("%v: %v", tc.desc, err) 501 } 502 wantCols := []string{ 503 "fieldDatetime", 504 "fieldDate", 505 } 506 if !reflect.DeepEqual(cols, wantCols) { 507 t.Errorf("%v: cols: %v, want %v", tc.desc, cols, wantCols) 508 } 509 510 location := time.UTC 511 if tc.config.DefaultLocation != "" { 512 location, err = time.LoadLocation(tc.config.DefaultLocation) 513 if err != nil { 514 t.Errorf("%v: %v", tc.desc, err) 515 } 516 } 517 518 count := 0 519 wantValues := []struct { 520 fieldDatetime time.Time 521 fieldDate time.Time 522 }{{ 523 time.Date(2009, 3, 29, 17, 22, 11, 0, location), 524 time.Date(2006, 7, 2, 0, 0, 0, 0, location), 525 }, { 526 time.Time{}, 527 time.Time{}, 528 }} 529 530 for r.Next() { 531 var fieldDatetime time.Time 532 var fieldDate time.Time 533 err := r.Scan(&fieldDatetime, &fieldDate) 534 if err != nil { 535 t.Errorf("%v: %v", tc.desc, err) 536 } 537 if want := wantValues[count].fieldDatetime; fieldDatetime != want { 538 t.Errorf("%v: wrong value for fieldDatetime: got: %v want: %v", tc.desc, fieldDatetime, want) 539 } 540 if want := wantValues[count].fieldDate; fieldDate != want { 541 t.Errorf("%v: wrong value for fieldDate: got: %v want: %v", tc.desc, fieldDate, want) 542 } 543 count++ 544 } 545 546 if count != len(wantValues) { 547 t.Errorf("%v: count: %d, want %d", tc.desc, count, len(wantValues)) 548 } 549 } 550 } 551 552 func TestTx(t *testing.T) { 553 c := Configuration{ 554 Protocol: "grpc", 555 Address: testAddress, 556 Target: "@primary", 557 } 558 559 db, err := OpenWithConfiguration(c) 560 if err != nil { 561 t.Fatal(err) 562 } 563 defer db.Close() 564 565 tx, err := db.Begin() 566 if err != nil { 567 t.Fatal(err) 568 } 569 570 s, err := tx.Prepare("txRequest") 571 if err != nil { 572 t.Fatal(err) 573 } 574 575 _, err = s.Exec(int64(0)) 576 if err != nil { 577 t.Fatal(err) 578 } 579 err = tx.Commit() 580 if err != nil { 581 t.Fatal(err) 582 } 583 // Commit on committed transaction is caught by Golang sql package. 584 // We actually don't have to cover this in our code. 585 err = tx.Commit() 586 if err != sql.ErrTxDone { 587 t.Errorf("err: %v, not ErrTxDone", err) 588 } 589 590 // Test rollback now. 591 tx, err = db.Begin() 592 if err != nil { 593 t.Fatal(err) 594 } 595 s, err = tx.Prepare("txRequest") 596 if err != nil { 597 t.Fatal(err) 598 } 599 _, err = s.Query(int64(0)) 600 if err != nil { 601 t.Fatal(err) 602 } 603 err = tx.Rollback() 604 if err != nil { 605 t.Fatal(err) 606 } 607 // Rollback on rolled back transaction is caught by Golang sql package. 608 // We actually don't have to cover this in our code. 609 err = tx.Rollback() 610 if err != sql.ErrTxDone { 611 t.Errorf("err: %v, not ErrTxDone", err) 612 } 613 } 614 615 func TestTxExecStreamingNotAllowed(t *testing.T) { 616 db, err := OpenForStreaming(testAddress, "@rdonly") 617 if err != nil { 618 t.Fatal(err) 619 } 620 defer db.Close() 621 622 _, err = db.Begin() 623 want := "Exec not allowed for streaming connection" 624 if err == nil || !strings.Contains(err.Error(), want) { 625 t.Errorf("err: %v, does not contain %s", err, want) 626 } 627 } 628 629 func TestSessionToken(t *testing.T) { 630 c := Configuration{ 631 Protocol: "grpc", 632 Address: testAddress, 633 Target: "@primary", 634 } 635 636 ctx := context.Background() 637 638 db, err := OpenWithConfiguration(c) 639 if err != nil { 640 t.Fatal(err) 641 } 642 defer db.Close() 643 644 tx, err := db.Begin() 645 if err != nil { 646 t.Fatal(err) 647 } 648 649 s, err := tx.Prepare("txRequest") 650 if err != nil { 651 t.Fatal(err) 652 } 653 654 _, err = s.Exec(int64(0)) 655 if err != nil { 656 t.Fatal(err) 657 } 658 659 sessionToken, err := SessionTokenFromTx(ctx, tx) 660 if err != nil { 661 t.Fatal(err) 662 } 663 664 distributedTxConfig := Configuration{ 665 Address: testAddress, 666 Target: "@primary", 667 SessionToken: sessionToken, 668 } 669 670 sameTx, sameValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) 671 if err != nil { 672 t.Fatal(err) 673 } 674 675 newS, err := sameTx.Prepare("distributedTxRequest") 676 if err != nil { 677 t.Fatal(err) 678 } 679 680 _, err = newS.Exec(int64(1)) 681 if err != nil { 682 t.Fatal(err) 683 } 684 685 err = sameValidationFunc() 686 if err != nil { 687 t.Fatal(err) 688 } 689 690 // enforce that Rollback can't be called on the distributed tx 691 noRollbackTx, noRollbackValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) 692 if err != nil { 693 t.Fatal(err) 694 } 695 696 err = noRollbackValidationFunc() 697 if err != nil { 698 t.Fatal(err) 699 } 700 701 err = noRollbackTx.Rollback() 702 if err == nil || err.Error() != "calling Rollback from a distributed tx is not allowed" { 703 t.Fatal(err) 704 } 705 706 // enforce that Commit can't be called on the distributed tx 707 noCommitTx, noCommitValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) 708 if err != nil { 709 t.Fatal(err) 710 } 711 712 err = noCommitValidationFunc() 713 if err != nil { 714 t.Fatal(err) 715 } 716 717 err = noCommitTx.Commit() 718 if err == nil || err.Error() != "calling Commit from a distributed tx is not allowed" { 719 t.Fatal(err) 720 } 721 722 // finally commit the original tx 723 err = tx.Commit() 724 if err != nil { 725 t.Fatal(err) 726 } 727 }